diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index c3dc6a2c4..0f65a551e 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -45,8 +45,7 @@ from config import ( AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_VOICE, - config_get, - config_set, + AppConfig, ) log = logging.getLogger(__name__) @@ -61,11 +60,11 @@ app.add_middleware( allow_headers=["*"], ) - -app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY -app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL -app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE +app.state.config = AppConfig() +app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY +app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL +app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), - "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), - "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -99,22 +98,17 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - config_set(app.state.OPENAI_API_BASE_URL, form_data.url) - config_set(app.state.OPENAI_API_KEY, form_data.key) - config_set(app.state.OPENAI_API_MODEL, form_data.model) - config_set(app.state.OPENAI_API_VOICE, form_data.speaker) - - app.state.OPENAI_API_BASE_URL.save() - app.state.OPENAI_API_KEY.save() - app.state.OPENAI_API_MODEL.save() - app.state.OPENAI_API_VOICE.save() + app.state.config.OPENAI_API_BASE_URL = form_data.url + app.state.config.OPENAI_API_KEY = form_data.key + app.state.config.OPENAI_API_MODEL = form_data.model + app.state.config.OPENAI_API_VOICE = form_data.speaker return { "status": True, - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), - "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), - "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, stream=True, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 8ebfb0446..1c309439d 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,8 +42,7 @@ from config import ( IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, - config_get, - config_set, + AppConfig, ) @@ -62,28 +61,30 @@ app.add_middleware( allow_headers=["*"], ) -app.state.ENGINE = IMAGE_GENERATION_ENGINE -app.state.ENABLED = ENABLE_IMAGE_GENERATION +app.state.config = AppConfig() -app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLED = ENABLE_IMAGE_GENERATION -app.state.MODEL = IMAGE_GENERATION_MODEL +app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.MODEL = IMAGE_GENERATION_MODEL -app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.IMAGE_SIZE = IMAGE_SIZE -app.state.IMAGE_STEPS = IMAGE_STEPS +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "engine": config_get(app.state.ENGINE), - "enabled": config_get(app.state.ENABLED), + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, } @@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.ENGINE, form_data.engine) - config_set(app.state.ENABLED, form_data.enabled) + app.state.config.ENGINE = form_data.engine + app.state.config.ENABLED = form_data.enabled return { - "engine": config_get(app.state.ENGINE), - "enabled": config_get(app.state.ENABLED), + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, } @@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel): @app.get("/url") async def get_engine_url(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), - "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, } @@ -121,29 +122,29 @@ async def update_engine_url( ): if form_data.AUTOMATIC1111_BASE_URL == None: - config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL)) + app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) - config_set(app.state.AUTOMATIC1111_BASE_URL, url) + app.state.config.AUTOMATIC1111_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) if form_data.COMFYUI_BASE_URL == None: - config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL) + app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) - config_set(app.state.COMFYUI_BASE_URL, url) + app.state.config.COMFYUI_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { - "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), - "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, "status": True, } @@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/openai/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -168,13 +169,13 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - config_set(app.state.OPENAI_API_BASE_URL, form_data.url) - config_set(app.state.OPENAI_API_KEY, form_data.key) + app.state.config.OPENAI_API_BASE_URL = form_data.url + app.state.config.OPENAI_API_KEY = form_data.key return { "status": True, - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel): @app.get("/size") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)} + return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE} @app.post("/size/update") @@ -193,9 +194,9 @@ async def update_image_size( ): pattern = r"^\d+x\d+$" # Regular expression pattern if re.match(pattern, form_data.size): - config_set(app.state.IMAGE_SIZE, form_data.size) + app.state.config.IMAGE_SIZE = form_data.size return { - "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE), + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, "status": True, } else: @@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel): @app.get("/steps") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)} + return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS} @app.post("/steps/update") @@ -219,9 +220,9 @@ async def update_image_size( form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) ): if form_data.steps >= 0: - config_set(app.state.IMAGE_STEPS, form_data.steps) + app.state.config.IMAGE_STEPS = form_data.steps return { - "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS), + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, "status": True, } else: @@ -234,14 +235,14 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": - r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") + r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() return list( @@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)): else: r = requests.get( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" ) models = r.json() return list( @@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)): ) ) except Exception as e: - app.state.ENABLED = False + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": return { "model": ( - config_get(app.state.MODEL) - if config_get(app.state.MODEL) - else "dall-e-2" - ) - } - elif app.state.ENGINE == "comfyui": - return { - "model": ( - config_get(app.state.MODEL) if config_get(app.state.MODEL) else "" + app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" ) } + elif app.state.config.ENGINE == "comfyui": + return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() return {"model": options["sd_model_checkpoint"]} except Exception as e: - config_set(app.state.ENABLED, False) + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE in ["openai", "comfyui"]: - config_set(app.state.MODEL, model) - return config_get(app.state.MODEL) + if app.state.config.ENGINE in ["openai", "comfyui"]: + app.state.config.MODEL = model + return app.state.config.MODEL else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + json=options, ) return options @@ -397,30 +397,32 @@ def generate_image( user=Depends(get_current_user), ): - width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x"))) + width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x")) r = None try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" data = { - "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "model": ( + app.state.config.MODEL + if app.state.config.MODEL != "" + else "dall-e-2" + ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size - if form_data.size - else config_get(app.state.IMAGE_SIZE) + form_data.size if form_data.size else app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", + url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -440,7 +442,7 @@ def generate_image( return images - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": data = { "prompt": form_data.prompt, @@ -449,8 +451,8 @@ def generate_image( "n": form_data.n, } - if config_get(app.state.IMAGE_STEPS) is not None: - data["steps"] = config_get(app.state.IMAGE_STEPS) + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -458,10 +460,10 @@ def generate_image( data = ImageGenerationPayload(**data) res = comfyui_generate_image( - config_get(app.state.MODEL), + app.state.config.MODEL, data, user.id, - config_get(app.state.COMFYUI_BASE_URL), + app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -488,14 +490,14 @@ def generate_image( "height": height, } - if config_get(app.state.IMAGE_STEPS) is not None: - data["steps"] = config_get(app.state.IMAGE_STEPS) + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, ) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 7dfadbb0c..cb80eeed2 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,8 +46,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, - config_set, - config_get, + AppConfig, ) from utils.misc import calculate_sha256 @@ -63,11 +62,12 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -98,7 +98,7 @@ async def get_status(): @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): @@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel): @app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.OLLAMA_BASE_URLS, form_data.urls) + app.state.config.OLLAMA_BASE_URLS = form_data.urls - log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} + log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -155,9 +155,7 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [ - fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS) - ] + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS] responses = await asyncio.gather(*tasks) models = { @@ -183,15 +181,14 @@ async def get_ollama_tags( if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] - in config_get(app.state.MODEL_FILTER_LIST), + lambda model: model["name"] in app.state.MODEL_FILTER_LIST, models["models"], ) ) return models return models else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): # returns lowest version tasks = [ - fetch_url(f"{url}/api/version") - for url in config_get(app.state.OLLAMA_BASE_URLS) + fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/version") r.raise_for_status() @@ -275,7 +271,7 @@ class ModelNameForm(BaseModel): async def pull_model( form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) ): - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -363,7 +359,7 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") r = None @@ -425,7 +421,7 @@ async def create_model( form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) ): log.debug(f"form_data: {form_data}") - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -498,7 +494,7 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -545,7 +541,7 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -642,7 +638,7 @@ async def generate_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -692,7 +688,7 @@ def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -761,7 +757,7 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -864,7 +860,7 @@ async def generate_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -973,7 +969,7 @@ async def generate_openai_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -1072,7 +1068,7 @@ async def get_openai_models( } else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1206,7 +1202,7 @@ async def download_model( if url_idx == None: url_idx = 0 - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1225,7 +1221,7 @@ async def download_model( def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): if url_idx == None: url_idx = 0 - ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" @@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # if url_idx == None: # url_idx = 0 -# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] +# url = app.state.config.OLLAMA_BASE_URLS[url_idx] # file_location = os.path.join(UPLOAD_DIR, file.filename) # total_size = file.size @@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): async def deprecated_proxy( path: str, request: Request, user=Depends(get_verified_user) ): - url = config_get(app.state.OLLAMA_BASE_URLS)[0] + url = app.state.config.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}" body = await request.body() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 36fed104c..5112ebb62 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,8 +26,7 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, - config_set, - config_get, + AppConfig, ) from typing import List, Optional @@ -47,11 +46,13 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() + app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.MODELS = {} @@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel): @app.get("/urls") async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): await get_all_models() - config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls) - return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} + app.state.config.OPENAI_API_BASE_URLS = form_data.urls + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.get("/keys") async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/keys/update") async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.OPENAI_API_KEYS, form_data.keys) - return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} + app.state.config.OPENAI_API_KEYS = form_data.keys + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = config_get(app.state.OPENAI_API_BASE_URLS).index( - "https://api.openai.com/v1" - ) + idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = ( - f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}" - ) + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", data=body, headers=headers, stream=True, @@ -187,7 +184,7 @@ def merge_models_lists(model_lists): {**model, "urlIdx": idx} for model in models if "api.openai.com" - not in config_get(app.state.OPENAI_API_BASE_URLS)[idx] + not in app.state.config.OPENAI_API_BASE_URLS[idx] or "gpt" in model["id"] ] ) @@ -199,14 +196,14 @@ async def get_all_models(): log.info("get_all_models()") if ( - len(config_get(app.state.OPENAI_API_KEYS)) == 1 - and config_get(app.state.OPENAI_API_KEYS)[0] == "" + len(app.state.config.OPENAI_API_KEYS) == 1 + and app.state.config.OPENAI_API_KEYS[0] == "" ): models = {"data": []} else: tasks = [ - fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx]) - for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS)) + fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) ] responses = await asyncio.gather(*tasks) @@ -238,19 +235,18 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if config_get(app.state.ENABLE_MODEL_FILTER): + if app.state.ENABLE_MODEL_FILTER: if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] - in config_get(app.state.MODEL_FILTER_LIST), + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, models["data"], ) ) return models return models else: - url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx] + url = app.state.config.OPENAI_API_BASE_URLS[url_idx] r = None @@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) - url = config_get(app.state.OPENAI_API_BASE_URLS)[idx] - key = config_get(app.state.OPENAI_API_KEYS)[idx] + url = app.state.config.OPENAI_API_BASE_URLS[idx] + key = app.state.config.OPENAI_API_KEYS[idx] target_url = f"{url}/{path}" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f05447a66..d2c3964ae 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,8 +93,7 @@ from config import ( RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, - config_set, - config_get, + AppConfig, ) from constants import ERROR_MESSAGES @@ -104,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) app = FastAPI() -app.state.TOP_K = RAG_TOP_K -app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config = AppConfig() -app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) -app.state.CHUNK_SIZE = CHUNK_SIZE -app.state.CHUNK_OVERLAP = CHUNK_OVERLAP +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP -app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE -app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY -app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES -app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -135,7 +136,7 @@ def update_embedding_model( embedding_model: str, update_model: bool = False, ): - if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "": + if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -160,22 +161,22 @@ def update_reranking_model( update_embedding_model( - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - config_get(app.state.RAG_RERANKING_MODEL), + app.state.config.RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, ) app.state.EMBEDDING_FUNCTION = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) origins = ["*"] @@ -202,12 +203,12 @@ class UrlForm(CollectionNameForm): async def get_status(): return { "status": True, - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), - "template": config_get(app.state.RAG_TEMPLATE), - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "template": app.state.config.RAG_TEMPLATE, + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } @@ -215,11 +216,11 @@ async def get_status(): async def get_embedding_config(user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": config_get(app.state.OPENAI_API_BASE_URL), - "key": config_get(app.state.OPENAI_API_KEY), + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } @@ -228,7 +229,7 @@ async def get_embedding_config(user=Depends(get_admin_user)): async def get_reraanking_config(user=Depends(get_admin_user)): return { "status": True, - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } @@ -248,34 +249,34 @@ async def update_embedding_config( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine) - config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model) + app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]: + if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config != None: - config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url) - config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key) + app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url + app.state.config.OPENAI_API_KEY = form_data.openai_config.key - update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True) + update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True app.state.EMBEDDING_FUNCTION = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) return { "status": True, - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": config_get(app.state.OPENAI_API_BASE_URL), - "key": config_get(app.state.OPENAI_API_KEY), + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } except Exception as e: @@ -295,16 +296,16 @@ async def update_reranking_config( form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model) + app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True) + update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True return { "status": True, - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -318,16 +319,14 @@ async def update_reranking_config( async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": config_get( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION - ), + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -352,49 +351,34 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - config_set( - app.state.PDF_EXTRACT_IMAGES, - ( - form_data.pdf_extract_images - if form_data.pdf_extract_images is not None - else config_get(app.state.PDF_EXTRACT_IMAGES) - ), + app.state.config.PDF_EXTRACT_IMAGES = ( + form_data.pdf_extract_images + if form_data.pdf_extract_images is not None + else app.state.config.PDF_EXTRACT_IMAGES ) - config_set( - app.state.CHUNK_SIZE, - ( - form_data.chunk.chunk_size - if form_data.chunk is not None - else config_get(app.state.CHUNK_SIZE) - ), + app.state.config.CHUNK_SIZE = ( + form_data.chunk.chunk_size + if form_data.chunk is not None + else app.state.config.CHUNK_SIZE ) - config_set( - app.state.CHUNK_OVERLAP, - ( - form_data.chunk.chunk_overlap - if form_data.chunk is not None - else config_get(app.state.CHUNK_OVERLAP) - ), + app.state.config.CHUNK_OVERLAP = ( + form_data.chunk.chunk_overlap + if form_data.chunk is not None + else app.state.config.CHUNK_OVERLAP ) - config_set( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ( - form_data.web_loader_ssl_verification - if form_data.web_loader_ssl_verification != None - else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION) - ), + app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + form_data.web_loader_ssl_verification + if form_data.web_loader_ssl_verification != None + else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) - config_set( - app.state.YOUTUBE_LOADER_LANGUAGE, - ( - form_data.youtube.language - if form_data.youtube is not None - else config_get(app.state.YOUTUBE_LOADER_LANGUAGE) - ), + app.state.config.YOUTUBE_LOADER_LANGUAGE = ( + form_data.youtube.language + if form_data.youtube is not None + else app.state.config.YOUTUBE_LOADER_LANGUAGE ) app.state.YOUTUBE_LOADER_TRANSLATION = ( @@ -405,16 +389,14 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ return { "status": True, - "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": config_get( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION - ), + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -424,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ async def get_rag_template(user=Depends(get_current_user)): return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), + "template": app.state.config.RAG_TEMPLATE, } @@ -432,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)): async def get_query_settings(user=Depends(get_admin_user)): return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), - "k": config_get(app.state.TOP_K), - "r": config_get(app.state.RELEVANCE_THRESHOLD), - "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -450,22 +432,20 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - config_set( - app.state.RAG_TEMPLATE, + app.state.config.RAG_TEMPLATE = ( form_data.template if form_data.template else RAG_TEMPLATE, ) - config_set(app.state.TOP_K, form_data.k if form_data.k else 4) - config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0) - config_set( - app.state.ENABLE_RAG_HYBRID_SEARCH, + app.state.config.TOP_K = form_data.k if form_data.k else 4 + app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False, ) return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), - "k": config_get(app.state.TOP_K), - "r": config_get(app.state.RELEVANCE_THRESHOLD), - "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -483,17 +463,15 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, r=( - form_data.r - if form_data.r - else config_get(app.state.RELEVANCE_THRESHOLD) + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD ), ) else: @@ -501,7 +479,7 @@ def query_doc_handler( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: log.exception(e) @@ -525,17 +503,15 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, r=( - form_data.r - if form_data.r - else config_get(app.state.RELEVANCE_THRESHOLD) + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD ), ) else: @@ -543,7 +519,7 @@ def query_collection_handler( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: @@ -560,8 +536,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, - language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE), - translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION), + language=app.state.config.YOUTUBE_LOADER_LANGUAGE, + translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) data = loader.load() @@ -589,7 +565,7 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): try: loader = get_web_loader( form_data.url, - verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION), + verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ) data = loader.load() @@ -645,8 +621,8 @@ def resolve_hostname(hostname): def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=config_get(app.state.CHUNK_SIZE), - chunk_overlap=config_get(app.state.CHUNK_OVERLAP), + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) @@ -663,8 +639,8 @@ def store_text_in_vector_db( text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=config_get(app.state.CHUNK_SIZE), - chunk_overlap=config_get(app.state.CHUNK_OVERLAP), + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) @@ -687,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) @@ -766,7 +742,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): if file_ext == "pdf": loader = PyPDFLoader( - file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES) + file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES ) elif file_ext == "csv": loader = CSVLoader(file_path) diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 2bed33543..755e3911b 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -22,21 +22,23 @@ from config import ( WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, - config_get, + AppConfig, ) app = FastAPI() origins = ["*"] -app.state.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN +app.state.config = AppConfig() -app.state.DEFAULT_MODELS = DEFAULT_MODELS -app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE -app.state.USER_PERMISSIONS = USER_PERMISSIONS -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS +app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.add_middleware( @@ -63,6 +65,6 @@ async def get_status(): return { "status": True, "auth": WEBUI_AUTH, - "default_models": config_get(app.state.DEFAULT_MODELS), - "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS), + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 0bc4967f9..998e74659 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -33,7 +33,7 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set +from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER router = APIRouter() @@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) return { @@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH: + if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) @@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): role = ( "admin" if Users.get_num_users() == 0 - else config_get(request.app.state.DEFAULT_USER_ROLE) + else request.app.state.config.DEFAULT_USER_ROLE ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( @@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration( - config_get(request.app.state.JWT_EXPIRES_IN) - ), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) # response.set_cookie(key='token', value=token, httponly=True) - if config_get(request.app.state.WEBHOOK_URL): + if request.app.state.config.WEBHOOK_URL: post_webhook( - config_get(request.app.state.WEBHOOK_URL), + request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", @@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/signup/enabled", response_model=bool) async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.ENABLE_SIGNUP) + return request.app.state.config.ENABLE_SIGNUP @router.get("/signup/enabled/toggle", response_model=bool) async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): - config_set( - request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP) - ) - return config_get(request.app.state.ENABLE_SIGNUP) + request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP + return request.app.state.config.ENABLE_SIGNUP ############################ @@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): @router.get("/signup/user/role") async def get_default_user_role(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.DEFAULT_USER_ROLE) + return request.app.state.config.DEFAULT_USER_ROLE class UpdateRoleForm(BaseModel): @@ -308,8 +304,8 @@ async def update_default_user_role( request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) ): if form_data.role in ["pending", "user", "admin"]: - config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role) - return config_get(request.app.state.DEFAULT_USER_ROLE) + request.app.state.config.DEFAULT_USER_ROLE = form_data.role + return request.app.state.config.DEFAULT_USER_ROLE ############################ @@ -319,7 +315,7 @@ async def update_default_user_role( @router.get("/token/expires") async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.JWT_EXPIRES_IN) + return request.app.state.config.JWT_EXPIRES_IN class UpdateJWTExpiresDurationForm(BaseModel): @@ -336,10 +332,10 @@ async def update_token_expires_duration( # Check if the input string matches the pattern if re.match(pattern, form_data.duration): - config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration) - return config_get(request.app.state.JWT_EXPIRES_IN) + request.app.state.config.JWT_EXPIRES_IN = form_data.duration + return request.app.state.config.JWT_EXPIRES_IN else: - return config_get(request.app.state.JWT_EXPIRES_IN) + return request.app.state.config.JWT_EXPIRES_IN ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index d726cd2dc..143ed5e0a 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -9,7 +9,6 @@ import time import uuid from apps.web.models.users import Users -from config import config_set, config_get from utils.utils import ( get_password_hash, @@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel): async def set_global_default_models( request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - config_set(request.app.state.DEFAULT_MODELS, form_data.models) - return config_get(request.app.state.DEFAULT_MODELS) + request.app.state.config.DEFAULT_MODELS = form_data.models + return request.app.state.config.DEFAULT_MODELS @router.post("/default/suggestions", response_model=List[PromptSuggestion]) @@ -56,5 +55,5 @@ async def set_global_default_suggestions( user=Depends(get_admin_user), ): data = form_data.model_dump() - config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"]) - return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS) + request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] + return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 302432540..d87854e89 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -15,7 +15,7 @@ from apps.web.models.auths import Auths from utils.utils import get_current_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS, config_set, config_get +from config import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) @router.get("/permissions/user") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.USER_PERMISSIONS) + return request.app.state.config.USER_PERMISSIONS @router.post("/permissions/user") async def update_user_permissions( request: Request, form_data: dict, user=Depends(get_admin_user) ): - config_set(request.app.state.USER_PERMISSIONS, form_data) - return config_get(request.app.state.USER_PERMISSIONS) + request.app.state.config.USER_PERMISSIONS = form_data + return request.app.state.config.USER_PERMISSIONS ############################ diff --git a/backend/config.py b/backend/config.py index 9e7a9ef90..845a812ce 100644 --- a/backend/config.py +++ b/backend/config.py @@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]): self.config_value = self.value -def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True): - if isinstance(config, WrappedConfig): - config.value = value - if save_config: - config.save() - else: - config = value +class AppConfig: + _state: dict[str, WrappedConfig] + def __init__(self): + super().__setattr__("_state", {}) -def config_get(config: Union[WrappedConfig[T], T]) -> T: - if isinstance(config, WrappedConfig): - return config.value - return config + def __setattr__(self, key, value): + if isinstance(value, WrappedConfig): + self._state[key] = value + else: + self._state[key].value = value + self._state[key].save() + + def __getattr__(self, key): + return self._state[key].value #################################### diff --git a/backend/main.py b/backend/main.py index 6f94a8dad..e2d7e18a3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,8 +58,7 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - config_get, - config_set, + AppConfig, ) from constants import ERROR_MESSAGES @@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config = AppConfig() +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.WEBHOOK_URL = WEBHOOK_URL origins = ["*"] @@ -245,11 +245,9 @@ async def get_app_config(): "version": VERSION, "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": config_get(images_app.state.ENABLED), - "default_models": config_get(webui_app.state.DEFAULT_MODELS), - "default_prompt_suggestions": config_get( - webui_app.state.DEFAULT_PROMPT_SUGGESTIONS - ), + "images": images_app.state.config.ENABLED, + "default_models": webui_app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -258,8 +256,8 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): return { - "enabled": config_get(app.state.ENABLE_MODEL_FILTER), - "models": config_get(app.state.MODEL_FILTER_LIST), + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): - config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled) - config_set(app.state.MODEL_FILTER_LIST, form_data.models) + app.state.config.ENABLE_MODEL_FILTER, form_data.enabled + app.state.config.MODEL_FILTER_LIST, form_data.models - ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST return { - "enabled": config_get(app.state.ENABLE_MODEL_FILTER), - "models": config_get(app.state.MODEL_FILTER_LIST), + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": config_get(app.state.WEBHOOK_URL), + "url": app.state.config.WEBHOOK_URL, } @@ -303,12 +301,12 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - config_set(app.state.WEBHOOK_URL, form_data.url) + app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL) + webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return { - "url": config_get(app.state.WEBHOOK_URL), + "url": app.state.config.WEBHOOK_URL, }