diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 2beec59f7..03e115cc0 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -42,10 +42,10 @@ router = APIRouter() async def get_config(request: Request, user=Depends(get_admin_user)): return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, - "engine": request.app.state.config.ENGINE, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, @@ -93,11 +93,13 @@ class ConfigForm(BaseModel): async def update_config( request: Request, form_data: ConfigForm, user=Depends(get_admin_user) ): - request.app.state.config.ENGINE = form_data.engine + request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled - request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( + form_data.openai.OPENAI_API_BASE_URL + ) + request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL @@ -132,10 +134,10 @@ async def update_config( return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, - "engine": request.app.state.config.ENGINE, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, @@ -166,7 +168,7 @@ def get_automatic1111_api_auth(request: Request): @router.get("/config/url/verify") async def verify_url(request: Request, user=Depends(get_admin_user)): - if request.app.state.config.ENGINE == "automatic1111": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111": try: r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", @@ -177,7 +179,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): except Exception: request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": try: r = requests.get( url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" @@ -194,7 +196,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") request.app.state.config.MODEL = model - if request.app.state.config.ENGINE in ["", "automatic1111"]: + if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", @@ -212,17 +214,17 @@ def set_image_model(request: Request, model: str): def get_image_model(): - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return ( request.app.state.config.MODEL if request.app.state.config.MODEL else "dall-e-2" ) - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": return request.app.state.config.MODEL if request.app.state.config.MODEL else "" elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): try: r = requests.get( @@ -285,12 +287,12 @@ async def update_image_config( @router.get("/models") def get_models(request: Request, user=Depends(get_verified_user)): try: - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui r = requests.get( url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" @@ -336,8 +338,8 @@ def get_models(request: Request, user=Depends(get_verified_user)): ) ) elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", @@ -433,10 +435,10 @@ async def image_generations( r = None try: - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": headers = {} headers["Authorization"] = ( - f"Bearer {request.app.state.config.OPENAI_API_KEY}" + f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}" ) headers["Content-Type"] = "application/json" @@ -465,7 +467,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -485,7 +487,7 @@ async def image_generations( return images - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -531,8 +533,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): if form_data.model: set_image_model(form_data.model)