diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 87732d7bc..0f65a551e 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -45,6 +45,7 @@ from config import ( AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_VOICE, + AppConfig, ) log = logging.getLogger(__name__) @@ -59,11 +60,11 @@ app.add_middleware( allow_headers=["*"], ) - -app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY -app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL -app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE +app.state.config = AppConfig() +app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY +app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL +app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -97,17 +98,17 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key - app.state.OPENAI_API_MODEL = form_data.model - app.state.OPENAI_API_VOICE = form_data.speaker + app.state.config.OPENAI_API_BASE_URL = form_data.url + 
app.state.config.OPENAI_API_KEY = form_data.key + app.state.config.OPENAI_API_MODEL = form_data.model + app.state.config.OPENAI_API_VOICE = form_data.speaker return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, stream=True, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index f45cf0d12..1c309439d 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,6 +42,7 @@ from config import ( IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, + AppConfig, ) @@ -60,26 +61,31 @@ app.add_middleware( allow_headers=["*"], ) -app.state.ENGINE = IMAGE_GENERATION_ENGINE -app.state.ENABLED = ENABLE_IMAGE_GENERATION +app.state.config = AppConfig() -app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLED = ENABLE_IMAGE_GENERATION -app.state.MODEL = IMAGE_GENERATION_MODEL +app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.MODEL = 
IMAGE_GENERATION_MODEL -app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.IMAGE_SIZE = IMAGE_SIZE -app.state.IMAGE_STEPS = IMAGE_STEPS +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + return { + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, + } class ConfigUpdateForm(BaseModel): @@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.ENGINE = form_data.engine - app.state.ENABLED = form_data.enabled - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + app.state.config.ENGINE = form_data.engine + app.state.config.ENABLED = form_data.enabled + return { + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, + } class EngineUrlUpdateForm(BaseModel): @@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel): @app.get("/url") async def get_engine_url(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, } @@ -113,29 +122,29 @@ async def update_engine_url( ): if form_data.AUTOMATIC1111_BASE_URL == None: - app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL + app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) - app.state.AUTOMATIC1111_BASE_URL = url + app.state.config.AUTOMATIC1111_BASE_URL = 
url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) if form_data.COMFYUI_BASE_URL == None: - app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) - app.state.COMFYUI_BASE_URL = url + app.state.config.COMFYUI_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, "status": True, } @@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/openai/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -160,13 +169,13 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key + app.state.config.OPENAI_API_BASE_URL = form_data.url + app.state.config.OPENAI_API_KEY = form_data.key return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel): @app.get("/size") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": app.state.IMAGE_SIZE} + return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE} @app.post("/size/update") @@ -185,9 
+194,9 @@ async def update_image_size( ): pattern = r"^\d+x\d+$" # Regular expression pattern if re.match(pattern, form_data.size): - app.state.IMAGE_SIZE = form_data.size + app.state.config.IMAGE_SIZE = form_data.size return { - "IMAGE_SIZE": app.state.IMAGE_SIZE, + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, "status": True, } else: @@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel): @app.get("/steps") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": app.state.IMAGE_STEPS} + return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS} @app.post("/steps/update") @@ -211,9 +220,9 @@ async def update_image_size( form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) ): if form_data.steps >= 0: - app.state.IMAGE_STEPS = form_data.steps + app.state.config.IMAGE_STEPS = form_data.steps return { - "IMAGE_STEPS": app.state.IMAGE_STEPS, + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, "status": True, } else: @@ -226,14 +235,14 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": - r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") + r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() return list( @@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)): else: r = requests.get( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" ) models = r.json() return list( @@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)): ) ) except Exception as e: - app.state.ENABLED = False + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) 
@app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - if app.state.ENGINE == "openai": - return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} - elif app.state.ENGINE == "comfyui": - return {"model": app.state.MODEL if app.state.MODEL else ""} + if app.state.config.ENGINE == "openai": + return { + "model": ( + app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" + ) + } + elif app.state.config.ENGINE == "comfyui": + return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() return {"model": options["sd_model_checkpoint"]} except Exception as e: - app.state.ENABLED = False + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE == "openai": - app.state.MODEL = model - return app.state.MODEL - if app.state.ENGINE == "comfyui": - app.state.MODEL = model - return app.state.MODEL + if app.state.config.ENGINE in ["openai", "comfyui"]: + app.state.config.MODEL = model + return app.state.config.MODEL else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + json=options, ) return options @@ -382,26 +397,32 @@ def generate_image( user=Depends(get_current_user), ): - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + width, height = 
tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) r = None try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" data = { - "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "model": ( + app.state.config.MODEL + if app.state.config.MODEL != "" + else "dall-e-2" + ), "prompt": form_data.prompt, "n": form_data.n, - "size": form_data.size if form_data.size else app.state.IMAGE_SIZE, + "size": ( + form_data.size if form_data.size else app.state.config.IMAGE_SIZE + ), "response_format": "b64_json", } r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", + url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -421,7 +442,7 @@ def generate_image( return images - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": data = { "prompt": form_data.prompt, @@ -430,19 +451,19 @@ def generate_image( "n": form_data.n, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS - if form_data.negative_prompt != None: + if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt data = ImageGenerationPayload(**data) res = comfyui_generate_image( - app.state.MODEL, + app.state.config.MODEL, data, user.id, - app.state.COMFYUI_BASE_URL, + app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -469,14 +490,14 @@ def generate_image( "height": height, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS - if form_data.negative_prompt != None: + if form_data.negative_prompt is
not None: data["negative_prompt"] = form_data.negative_prompt r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, ) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 042d0336d..cb80eeed2 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, + AppConfig, ) from utils.misc import calculate_sha256 @@ -61,11 +62,12 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -96,7 +98,7 @@ async def get_status(): @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): @@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel): @app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_BASE_URLS = form_data.urls + app.state.config.OLLAMA_BASE_URLS = form_data.urls - log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -153,7 +155,7 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS] 
responses = await asyncio.gather(*tasks) models = { @@ -186,7 +188,7 @@ async def get_ollama_tags( return models return models else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx == None: # returns lowest version - tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + tasks = [ + fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS + ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/version") r.raise_for_status() @@ -267,7 +271,7 @@ class ModelNameForm(BaseModel): async def pull_model( form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) ): - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -355,7 +359,7 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") r = None @@ -417,7 +421,7 @@ async def create_model( form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) ): log.debug(f"form_data: {form_data}") - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -490,7 +494,7 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = 
app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -537,7 +541,7 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -634,7 +638,7 @@ async def generate_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -684,7 +688,7 @@ def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -753,7 +757,7 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -856,7 +860,7 @@ async def generate_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -965,7 +969,7 @@ async def generate_openai_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -1064,7 +1068,7 @@ async def get_openai_models( } else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", 
url=f"{url}/api/tags") r.raise_for_status() @@ -1198,7 +1202,7 @@ async def download_model( if url_idx == None: url_idx = 0 - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1217,7 +1221,7 @@ async def download_model( def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): if url_idx == None: url_idx = 0 - ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] + ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" @@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # if url_idx == None: # url_idx = 0 -# url = app.state.OLLAMA_BASE_URLS[url_idx] +# url = app.state.config.OLLAMA_BASE_URLS[url_idx] # file_location = os.path.join(UPLOAD_DIR, file.filename) # total_size = file.size @@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): async def deprecated_proxy( path: str, request: Request, user=Depends(get_verified_user) ): - url = app.state.OLLAMA_BASE_URLS[0] + url = app.state.config.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}" body = await request.body() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index b38c2bc2d..fb16a579b 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,6 +26,7 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, + AppConfig, ) from typing import List, Optional @@ -45,11 +46,13 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() + app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS 
+app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.MODELS = {} @@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel): @app.get("/urls") async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): await get_all_models() - app.state.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + app.state.config.OPENAI_API_BASE_URLS = form_data.urls + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.get("/keys") async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/keys/update") async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + app.state.config.OPENAI_API_KEYS = form_data.keys + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -114,7 +117,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.OPENAI_API_BASE_URLS[idx]: + if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: headers['HTTP-Referer'] =
"https://openwebui.com/" @@ -122,7 +125,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", data=body, headers=headers, stream=True, @@ -182,7 +185,8 @@ def merge_models_lists(model_lists): [ {**model, "urlIdx": idx} for model in models - if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] + if "api.openai.com" + not in app.state.config.OPENAI_API_BASE_URLS[idx] or "gpt" in model["id"] ] ) @@ -193,12 +197,15 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": + if ( + len(app.state.config.OPENAI_API_KEYS) == 1 + and app.state.config.OPENAI_API_KEYS[0] == "" + ): models = {"data": []} else: tasks = [ - fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) + fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) ] responses = await asyncio.gather(*tasks) @@ -241,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use return models return models else: - url = app.state.OPENAI_API_BASE_URLS[url_idx] + url = app.state.config.OPENAI_API_BASE_URLS[url_idx] r = None @@ -305,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) - url = app.state.OPENAI_API_BASE_URLS[idx] - key = app.state.OPENAI_API_KEYS[idx] + url = app.state.config.OPENAI_API_BASE_URLS[idx] + key = app.state.config.OPENAI_API_KEYS[idx] target_url = f"{url}/{path}" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2e2a8e209..d2c3964ae 100644 --- 
a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,6 +93,7 @@ from config import ( RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, + AppConfig, ) from constants import ERROR_MESSAGES @@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) app = FastAPI() -app.state.TOP_K = RAG_TOP_K -app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config = AppConfig() -app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) -app.state.CHUNK_SIZE = CHUNK_SIZE -app.state.CHUNK_OVERLAP = CHUNK_OVERLAP +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP -app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE -app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY -app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES -app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -133,7 +136,7 @@ def update_embedding_model( embedding_model: str, update_model: bool = False, ): - 
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": + if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -158,22 +161,22 @@ def update_reranking_model( update_embedding_model( - app.state.RAG_EMBEDDING_MODEL, + app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - app.state.RAG_RERANKING_MODEL, + app.state.config.RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, ) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) origins = ["*"] @@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm): async def get_status(): return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, - "template": app.state.RAG_TEMPLATE, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "template": app.state.config.RAG_TEMPLATE, + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } @@ -213,18 +216,21 @@ async def get_status(): async def get_embedding_config(user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + 
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } @app.get("/reranking") async def get_reraanking_config(user=Depends(get_admin_user)): - return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} + return { + "status": True, + "reranking_model": app.state.config.RAG_RERANKING_MODEL, + } class OpenAIConfigForm(BaseModel): @@ -243,34 +249,34 @@ async def update_embedding_config( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config != None: - app.state.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.OPENAI_API_KEY = form_data.openai_config.key + app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url + app.state.config.OPENAI_API_KEY = form_data.openai_config.key - update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) + update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL, True) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, +
app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } except Exception as e: @@ -290,16 +296,16 @@ async def update_reranking_config( form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.RAG_RERANKING_MODEL = form_data.reranking_model + app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(app.state.RAG_RERANKING_MODEL, True) + update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) return { "status": True, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -313,14 +319,14 @@ async def update_reranking_config( async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language":
app.state.YOUTUBE_LOADER_LANGUAGE, + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.PDF_EXTRACT_IMAGES = ( + app.state.config.PDF_EXTRACT_IMAGES = ( form_data.pdf_extract_images - if form_data.pdf_extract_images != None - else app.state.PDF_EXTRACT_IMAGES + if form_data.pdf_extract_images is not None + else app.state.config.PDF_EXTRACT_IMAGES ) - app.state.CHUNK_SIZE = ( - form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE + app.state.config.CHUNK_SIZE = ( + form_data.chunk.chunk_size + if form_data.chunk is not None + else app.state.config.CHUNK_SIZE ) - app.state.CHUNK_OVERLAP = ( + app.state.config.CHUNK_OVERLAP = ( form_data.chunk.chunk_overlap - if form_data.chunk != None - else app.state.CHUNK_OVERLAP + if form_data.chunk is not None + else app.state.config.CHUNK_OVERLAP ) - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( form_data.web_loader_ssl_verification if form_data.web_loader_ssl_verification != None - else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) - app.state.YOUTUBE_LOADER_LANGUAGE = ( + app.state.config.YOUTUBE_LOADER_LANGUAGE = ( form_data.youtube.language - if form_data.youtube != None - else app.state.YOUTUBE_LOADER_LANGUAGE + if form_data.youtube is not None + else app.state.config.YOUTUBE_LOADER_LANGUAGE ) app.state.YOUTUBE_LOADER_TRANSLATION = ( form_data.youtube.translation - if form_data.youtube != None + if form_data.youtube is not None else app.state.YOUTUBE_LOADER_TRANSLATION ) return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - 
"chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ async def get_rag_template(user=Depends(get_current_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, + "template": app.state.config.RAG_TEMPLATE, } @@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)): async def get_query_settings(user=Depends(get_admin_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE - app.state.TOP_K = form_data.k if form_data.k else 4 - app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False + app.state.config.RAG_TEMPLATE = ( + form_data.template if form_data.template else RAG_TEMPLATE, + ) + app.state.config.TOP_K = form_data.k if form_data.k else 4 + app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + 
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + form_data.hybrid if form_data.hybrid else False + ) return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -451,21 +463,23 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), ) else: return query_doc( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: log.exception(e) @@ -489,21 +503,23 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r if 
form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), ) else: return query_collection( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: @@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, - language=app.state.YOUTUBE_LOADER_LANGUAGE, + language=app.state.config.YOUTUBE_LOADER_LANGUAGE, translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) data = loader.load() @@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = get_web_loader( - form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + form_data.url, + verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ) data = loader.load() @@ -604,8 +621,8 @@ def resolve_hostname(hostname): def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) @@ -622,8 +639,8 @@ def store_text_in_vector_db( text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) @@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = 
CHROMA_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) @@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ] if file_ext == "pdf": - loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) + loader = PyPDFLoader( + file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES + ) elif file_ext == "csv": loader = CSVLoader(file_path) elif file_ext == "rst": diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 66cdfb3d4..755e3911b 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -21,20 +21,24 @@ from config import ( USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + JWT_EXPIRES_IN, + AppConfig, ) app = FastAPI() origins = ["*"] -app.state.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.JWT_EXPIRES_IN = "-1" +app.state.config = AppConfig() -app.state.DEFAULT_MODELS = DEFAULT_MODELS -app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE -app.state.USER_PERMISSIONS = USER_PERMISSIONS -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS +app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.add_middleware( @@ 
-61,6 +65,6 @@ async def get_status(): return { "status": True, "auth": WEBUI_AUTH, - "default_models": app.state.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 9fa962dda..998e74659 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) return { @@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: + if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) @@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): role = ( "admin" if Users.get_num_users() == 0 - else request.app.state.DEFAULT_USER_ROLE + else request.app.state.config.DEFAULT_USER_ROLE ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( @@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) # response.set_cookie(key='token', value=token, httponly=True) - if request.app.state.WEBHOOK_URL: + if request.app.state.config.WEBHOOK_URL: post_webhook( - request.app.state.WEBHOOK_URL, + 
request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", @@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/signup/enabled", response_model=bool) async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): - return request.app.state.ENABLE_SIGNUP + return request.app.state.config.ENABLE_SIGNUP @router.get("/signup/enabled/toggle", response_model=bool) async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): - request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP - return request.app.state.ENABLE_SIGNUP + request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP + return request.app.state.config.ENABLE_SIGNUP ############################ @@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): @router.get("/signup/user/role") async def get_default_user_role(request: Request, user=Depends(get_admin_user)): - return request.app.state.DEFAULT_USER_ROLE + return request.app.state.config.DEFAULT_USER_ROLE class UpdateRoleForm(BaseModel): @@ -304,8 +304,8 @@ async def update_default_user_role( request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) ): if form_data.role in ["pending", "user", "admin"]: - request.app.state.DEFAULT_USER_ROLE = form_data.role - return request.app.state.DEFAULT_USER_ROLE + request.app.state.config.DEFAULT_USER_ROLE = form_data.role + return request.app.state.config.DEFAULT_USER_ROLE ############################ @@ -315,7 +315,7 @@ async def update_default_user_role( @router.get("/token/expires") async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): - return request.app.state.JWT_EXPIRES_IN + return request.app.state.config.JWT_EXPIRES_IN class UpdateJWTExpiresDurationForm(BaseModel): @@ -332,10 +332,10 @@ async def update_token_expires_duration( # Check if the input string matches 
the pattern if re.match(pattern, form_data.duration): - request.app.state.JWT_EXPIRES_IN = form_data.duration - return request.app.state.JWT_EXPIRES_IN + request.app.state.config.JWT_EXPIRES_IN = form_data.duration + return request.app.state.config.JWT_EXPIRES_IN else: - return request.app.state.JWT_EXPIRES_IN + return request.app.state.config.JWT_EXPIRES_IN ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 0bad55a6a..143ed5e0a 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel): async def set_global_default_models( request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - request.app.state.DEFAULT_MODELS = form_data.models - return request.app.state.DEFAULT_MODELS + request.app.state.config.DEFAULT_MODELS = form_data.models + return request.app.state.config.DEFAULT_MODELS @router.post("/default/suggestions", response_model=List[PromptSuggestion]) @@ -55,5 +55,5 @@ async def set_global_default_suggestions( user=Depends(get_admin_user), ): data = form_data.model_dump() - request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] - return request.app.state.DEFAULT_PROMPT_SUGGESTIONS + request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] + return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 59f6c21b7..d87854e89 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) @router.get("/permissions/user") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return request.app.state.USER_PERMISSIONS + return request.app.state.config.USER_PERMISSIONS @router.post("/permissions/user") async def 
update_user_permissions( request: Request, form_data: dict, user=Depends(get_admin_user) ): - request.app.state.USER_PERMISSIONS = form_data - return request.app.state.USER_PERMISSIONS + request.app.state.config.USER_PERMISSIONS = form_data + return request.app.state.config.USER_PERMISSIONS ############################ diff --git a/backend/config.py b/backend/config.py index 5c6247a9f..112edba90 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,6 +5,7 @@ import chromadb from chromadb import Settings from base64 import b64encode from bs4 import BeautifulSoup +from typing import TypeVar, Generic, Union from pathlib import Path import json @@ -17,7 +18,6 @@ import shutil from secrets import token_bytes from constants import ERROR_MESSAGES - #################################### # Load .env file #################################### @@ -71,7 +71,6 @@ for source in log_sources: log.setLevel(SRC_LOG_LEVELS["CONFIG"]) - WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") if WEBUI_NAME != "Open WebUI": WEBUI_NAME += " (Open WebUI)" @@ -161,16 +160,6 @@ CHANGELOG = changelog_json WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") -#################################### -# WEBUI_AUTH (Required for security) -#################################### - -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) - - #################################### # DATA/FRONTEND BUILD DIR #################################### @@ -184,6 +173,108 @@ try: except: CONFIG_DATA = {} + +#################################### +# Config helpers +#################################### + + +def save_config(): + try: + with open(f"{DATA_DIR}/config.json", "w") as f: + json.dump(CONFIG_DATA, f, indent="\t") + except Exception as e: + log.exception(e) + + +def get_config_value(config_path: str): + path_parts = config_path.split(".") + cur_config = CONFIG_DATA + for key in 
path_parts: + if key in cur_config: + cur_config = cur_config[key] + else: + return None + return cur_config + + +T = TypeVar("T") + + +class PersistentConfig(Generic[T]): + def __init__(self, env_name: str, config_path: str, env_value: T): + self.env_name = env_name + self.config_path = config_path + self.env_value = env_value + self.config_value = get_config_value(config_path) + if self.config_value is not None: + log.info(f"'{env_name}' loaded from config.json") + self.value = self.config_value + else: + self.value = env_value + + def __str__(self): + return str(self.value) + + @property + def __dict__(self): + raise TypeError( + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." + ) + + def __getattribute__(self, item): + if item == "__dict__": + raise TypeError( + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." + ) + return super().__getattribute__(item) + + def save(self): + # Don't save if the value is the same as the env value and the config value + if self.env_value == self.value: + if self.config_value == self.value: + return + log.info(f"Saving '{self.env_name}' to config.json") + path_parts = self.config_path.split(".") + config = CONFIG_DATA + for key in path_parts[:-1]: + if key not in config: + config[key] = {} + config = config[key] + config[path_parts[-1]] = self.value + save_config() + self.config_value = self.value + + +class AppConfig: + _state: dict[str, PersistentConfig] + + def __init__(self): + super().__setattr__("_state", {}) + + def __setattr__(self, key, value): + if isinstance(value, PersistentConfig): + self._state[key] = value + else: + self._state[key].value = value + self._state[key].save() + + def __getattr__(self, key): + return self._state[key].value + + +#################################### +# WEBUI_AUTH (Required for security) +#################################### + +WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" 
+WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) +JWT_EXPIRES_IN = PersistentConfig( + "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") +) + #################################### # Static DIR #################################### @@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] - +OLLAMA_BASE_URLS = PersistentConfig( + "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS +) #################################### # OPENAI_API @@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] - +OPENAI_API_KEYS = PersistentConfig( + "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS +) OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = ( @@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [ url.strip() if url != "" else "https://api.openai.com/v1" for url in OPENAI_API_BASE_URLS.split(";") ] +OPENAI_API_BASE_URLS = PersistentConfig( + "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS +) OPENAI_API_KEY = "" try: - OPENAI_API_KEY = OPENAI_API_KEYS[ - OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + OPENAI_API_KEY = OPENAI_API_KEYS.value[ + OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" - #################################### # WEBUI #################################### -ENABLE_SIGNUP = ( - False - if WEBUI_AUTH == False - else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" +ENABLE_SIGNUP = PersistentConfig( + "ENABLE_SIGNUP", + "ui.enable_signup", + ( + False + if not WEBUI_AUTH + else 
os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" + ), +) +DEFAULT_MODELS = PersistentConfig( + "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) - -DEFAULT_PROMPT_SUGGESTIONS = ( - CONFIG_DATA["ui"]["prompt_suggestions"] - if "ui" in CONFIG_DATA - and "prompt_suggestions" in CONFIG_DATA["ui"] - and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list - else [ +DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( + "DEFAULT_PROMPT_SUGGESTIONS", + "ui.prompt_suggestions", + [ { "title": ["Help me study", "vocabulary for a college entrance exam"], "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", @@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = ( "title": ["Overcome procrastination", "give me tips"], "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", }, - ] + ], ) - -DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") +DEFAULT_USER_ROLE = PersistentConfig( + "DEFAULT_USER_ROLE", + "ui.default_user_role", + os.getenv("DEFAULT_USER_ROLE", "pending"), +) USER_PERMISSIONS_CHAT_DELETION = ( os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" ) -USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} +USER_PERMISSIONS = PersistentConfig( + "USER_PERMISSIONS", + "ui.user_permissions", + {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, +) -ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true" +ENABLE_MODEL_FILTER = PersistentConfig( + "ENABLE_MODEL_FILTER", + "model_filter.enable", + os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", +) MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] +MODEL_FILTER_LIST = PersistentConfig( + 
"MODEL_FILTER_LIST", + "model_filter.list", + [model.strip() for model in MODEL_FILTER_LIST.split(";")], +) -WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") +WEBHOOK_URL = PersistentConfig( + "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") +) ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" @@ -458,26 +575,45 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) -RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) -RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) - -ENABLE_RAG_HYBRID_SEARCH = ( - os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" +RAG_TOP_K = PersistentConfig( + "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) +) +RAG_RELEVANCE_THRESHOLD = PersistentConfig( + "RAG_RELEVANCE_THRESHOLD", + "rag.relevance_threshold", + float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) - -ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true" +ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( + "ENABLE_RAG_HYBRID_SEARCH", + "rag.enable_hybrid_search", + os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) -RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") - -PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" - -RAG_EMBEDDING_MODEL = os.environ.get( - "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" +ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig( + "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", + "rag.enable_web_loader_ssl_verification", + os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", ) -log.info(f"Embedding model set: 
{RAG_EMBEDDING_MODEL}"), + +RAG_EMBEDDING_ENGINE = PersistentConfig( + "RAG_EMBEDDING_ENGINE", + "rag.embedding_engine", + os.environ.get("RAG_EMBEDDING_ENGINE", ""), +) + +PDF_EXTRACT_IMAGES = PersistentConfig( + "PDF_EXTRACT_IMAGES", + "rag.pdf_extract_images", + os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", +) + +RAG_EMBEDDING_MODEL = PersistentConfig( + "RAG_EMBEDDING_MODEL", + "rag.embedding_model", + os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), +) +log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"), RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") -if not RAG_RERANKING_MODEL == "": - log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"), +RAG_RERANKING_MODEL = PersistentConfig( + "RAG_RERANKING_MODEL", + "rag.reranking_model", + os.environ.get("RAG_RERANKING_MODEL", ""), +) +if RAG_RERANKING_MODEL.value != "": + log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"), RAG_RERANKING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true": else: DEVICE_TYPE = "cpu" - -CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) -CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) +CHUNK_SIZE = PersistentConfig( + "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) +) +CHUNK_OVERLAP = PersistentConfig( + "CHUNK_OVERLAP", + "rag.chunk_overlap", + int(os.environ.get("CHUNK_OVERLAP", "100")), +) DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags. @@ -545,16 +690,32 @@ And answer according to the language of the user's question. Given the context information, answer the query. 
Query: [query]""" -RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) +RAG_TEMPLATE = PersistentConfig( + "RAG_TEMPLATE", + "rag.template", + os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), +) -RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) +RAG_OPENAI_API_BASE_URL = PersistentConfig( + "RAG_OPENAI_API_BASE_URL", + "rag.openai_api_base_url", + os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +RAG_OPENAI_API_KEY = PersistentConfig( + "RAG_OPENAI_API_KEY", + "rag.openai_api_key", + os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), +) ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) -YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",") +YOUTUBE_LOADER_LANGUAGE = PersistentConfig( + "YOUTUBE_LOADER_LANGUAGE", + "rag.youtube_loader_language", + os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), +) #################################### # Transcribe @@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = ( # Images #################################### -IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "") - -ENABLE_IMAGE_GENERATION = ( - os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" +IMAGE_GENERATION_ENGINE = PersistentConfig( + "IMAGE_GENERATION_ENGINE", + "image_generation.engine", + os.getenv("IMAGE_GENERATION_ENGINE", ""), ) -AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") -COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") - -IMAGES_OPENAI_API_BASE_URL = os.getenv( - "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL +ENABLE_IMAGE_GENERATION = PersistentConfig( + "ENABLE_IMAGE_GENERATION", + "image_generation.enable", + os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", +) +AUTOMATIC1111_BASE_URL = PersistentConfig( + "AUTOMATIC1111_BASE_URL", + "image_generation.automatic1111.base_url", + 
os.getenv("AUTOMATIC1111_BASE_URL", ""), ) -IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY) -IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512") +COMFYUI_BASE_URL = PersistentConfig( + "COMFYUI_BASE_URL", + "image_generation.comfyui.base_url", + os.getenv("COMFYUI_BASE_URL", ""), +) -IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50)) +IMAGES_OPENAI_API_BASE_URL = PersistentConfig( + "IMAGES_OPENAI_API_BASE_URL", + "image_generation.openai.api_base_url", + os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +IMAGES_OPENAI_API_KEY = PersistentConfig( + "IMAGES_OPENAI_API_KEY", + "image_generation.openai.api_key", + os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), +) -IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "") +IMAGE_SIZE = PersistentConfig( + "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") +) + +IMAGE_STEPS = PersistentConfig( + "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) +) + +IMAGE_GENERATION_MODEL = PersistentConfig( + "IMAGE_GENERATION_MODEL", + "image_generation.model", + os.getenv("IMAGE_GENERATION_MODEL", ""), +) #################################### # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) -AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1") -AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy") +AUDIO_OPENAI_API_BASE_URL = PersistentConfig( + "AUDIO_OPENAI_API_BASE_URL", + "audio.openai.api_base_url", + os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +AUDIO_OPENAI_API_KEY = PersistentConfig( + "AUDIO_OPENAI_API_KEY", + "audio.openai.api_key", + os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), +) +AUDIO_OPENAI_API_MODEL = PersistentConfig( + "AUDIO_OPENAI_API_MODEL", + "audio.openai.api_model", + os.getenv("AUDIO_OPENAI_API_MODEL", 
"tts-1"), +) +AUDIO_OPENAI_API_VOICE = PersistentConfig( + "AUDIO_OPENAI_API_VOICE", + "audio.openai.api_voice", + os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), +) #################################### # LiteLLM diff --git a/backend/main.py b/backend/main.py index 2d8d9ed68..8b7f9af69 100644 --- a/backend/main.py +++ b/backend/main.py @@ -59,6 +59,7 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, + AppConfig, ) from constants import ERROR_MESSAGES @@ -107,10 +108,11 @@ app = FastAPI( docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan ) -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config = AppConfig() +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.WEBHOOK_URL = WEBHOOK_URL origins = ["*"] @@ -250,9 +252,9 @@ async def get_app_config(): "version": VERSION, "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": images_app.state.ENABLED, - "default_models": webui_app.state.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, + "images": images_app.state.config.ENABLED, + "default_models": webui_app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -261,8 +263,8 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @@ -275,28 +277,28 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: 
ModelFilterConfigForm, user=Depends(get_admin_user) ): - app.state.ENABLE_MODEL_FILTER = form_data.enabled - app.state.MODEL_FILTER_LIST = form_data.models + app.state.config.ENABLE_MODEL_FILTER = form_data.enabled + app.state.config.MODEL_FILTER_LIST = form_data.models - ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": app.state.WEBHOOK_URL, + "url": app.state.config.WEBHOOK_URL, } @@ -306,12 +308,12 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - app.state.WEBHOOK_URL = form_data.url + app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL + webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return { - "url": app.state.WEBHOOK_URL, + "url": app.state.config.WEBHOOK_URL, }