From 772f5ccd60f1967bbb62afe16bbdacc481b3fec1 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:53:38 -0800 Subject: [PATCH] wip --- backend/open_webui/main.py | 52 ++--- backend/open_webui/routers/images.py | 280 +++++++++++++++------------ backend/open_webui/routers/webui.py | 94 --------- 3 files changed, 180 insertions(+), 246 deletions(-) delete mode 100644 backend/open_webui/routers/webui.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 09913ae01..84af1685a 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -697,11 +697,11 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): if not filter: continue - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] + if filter_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[filter_id] else: function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module + app.state.FUNCTIONS[filter_id] = function_module # Check if the function has a file_handler variable if hasattr(function_module, "file_handler"): @@ -828,7 +828,7 @@ async def chat_completion_tools_handler( models, ) tools = get_tools( - webui_app, + app, tool_ids, user, { @@ -1406,7 +1406,7 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) - request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) @@ -1913,11 +1913,11 @@ async def get_all_models(): ] def get_function_module_by_id(function_id): - if function_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[function_id] + if function_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[function_id] else: function_module, _, _ = load_function_module_by_id(function_id) - webui_app.state.FUNCTIONS[function_id] = function_module + app.state.FUNCTIONS[function_id] = function_module for model in models: action_ids = [ @@ -1953,7 +1953,7 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] - model_order_list = webui_app.state.config.MODEL_ORDER_LIST + model_order_list = app.state.config.MODEL_ORDER_LIST if model_order_list: model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} # Sort models by order list priority, with fallback for those not in the list @@ -2229,11 +2229,11 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): if not filter: continue - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] + if filter_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[filter_id] else: function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module + app.state.FUNCTIONS[filter_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(filter_id) @@ -2340,11 +2340,11 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified } ) - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] + if action_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[action_id] else: function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module + app.state.FUNCTIONS[action_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(action_id) @@ -2448,17 +2448,17 @@ async def get_app_config(request: Request): }, "features": { "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_ldap": webui_app.state.config.ENABLE_LDAP, - "enable_api_key": webui_app.state.config.ENABLE_API_KEY, - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_ldap": app.state.config.ENABLE_LDAP, + "enable_api_key": app.state.config.ENABLE_API_KEY, + "enable_signup": app.state.config.ENABLE_SIGNUP, + "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, **( { "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, } @@ -2468,8 +2468,8 @@ async def get_app_config(request: Request): }, **( { - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "audio": { "tts": { "engine": audio_app.state.config.TTS_ENGINE, @@ -2484,7 +2484,7 @@ async def get_app_config(request: Request): "max_size": retrieval_app.state.config.FILE_MAX_SIZE, "max_count": retrieval_app.state.config.FILE_MAX_COUNT, }, - "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + "permissions": {**app.state.config.USER_PERMISSIONS}, } if user is not None else {} @@ -2506,7 +2506,7 @@ async def get_webhook_url(user=Depends(get_admin_user)): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL + app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return {"url": app.state.config.WEBHOOK_URL} diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index f4c12ab64..0deded03e 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -9,6 +9,18 @@ from pathlib import Path from typing import Optional import requests + + +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS + +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, @@ -16,48 +28,36 @@ from open_webui.utils.images.comfyui import ( ) -from open_webui.config import CACHE_DIR -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS - -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from open_webui.utils.auth import get_admin_user, get_verified_user - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) + +router = APIRouter() -@app.get("/config") +@router.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLED, + "engine": request.app.state.config.ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } @@ -89,133 +89,150 @@ class ConfigForm(BaseModel): comfyui: ComfyUIConfigForm -@app.post("/config/update") -async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled +@router.post("/config/update") +async def update_config( + request: Request, form_data: ConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENGINE = form_data.engine + request.app.state.config.ENABLED = form_data.enabled - app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL + request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY - app.state.config.AUTOMATIC1111_BASE_URL = ( + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) - app.state.config.AUTOMATIC1111_API_AUTH = ( + request.app.state.config.AUTOMATIC1111_API_AUTH = ( form_data.automatic1111.AUTOMATIC1111_API_AUTH ) - app.state.config.AUTOMATIC1111_CFG_SCALE = ( + request.app.state.config.AUTOMATIC1111_CFG_SCALE = ( float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE else None ) - app.state.config.AUTOMATIC1111_SAMPLER = ( + request.app.state.config.AUTOMATIC1111_SAMPLER = ( form_data.automatic1111.AUTOMATIC1111_SAMPLER if form_data.automatic1111.AUTOMATIC1111_SAMPLER else None ) - app.state.config.AUTOMATIC1111_SCHEDULER = ( + request.app.state.config.AUTOMATIC1111_SCHEDULER = ( form_data.automatic1111.AUTOMATIC1111_SCHEDULER if form_data.automatic1111.AUTOMATIC1111_SCHEDULER else None ) - app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") - app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW - app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + request.app.state.config.COMFYUI_BASE_URL = ( + form_data.comfyui.COMFYUI_BASE_URL.strip("/") + ) + request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + request.app.state.config.COMFYUI_WORKFLOW_NODES = ( + form_data.comfyui.COMFYUI_WORKFLOW_NODES + ) return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLED, + "engine": request.app.state.config.ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } -def get_automatic1111_api_auth(): - if app.state.config.AUTOMATIC1111_API_AUTH is None: +def get_automatic1111_api_auth(request: Request): + if request.app.state.config.AUTOMATIC1111_API_AUTH is None: return "" else: - auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") + auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode( + "utf-8" + ) auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config/url/verify") -async def verify_url(user=Depends(get_admin_user)): - if app.state.config.ENGINE == "automatic1111": +@router.get("/config/url/verify") +async def verify_url(request: Request, user=Depends(get_admin_user)): + if request.app.state.config.ENGINE == "automatic1111": try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth(request)}, ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": try: - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: return True -def set_image_model(model: str): +def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") - app.state.config.MODEL = model - if app.state.config.ENGINE in ["", "automatic1111"]: + request.app.state.config.MODEL = model + if request.app.state.config.ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": api_auth}, ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options, headers={"authorization": api_auth}, ) - return app.state.config.MODEL + return request.app.state.config.MODEL def get_image_model(): - if app.state.config.ENGINE == "openai": - return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - elif app.state.config.ENGINE == "comfyui": - return app.state.config.MODEL if app.state.config.MODEL else "" - elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": + if request.app.state.config.ENGINE == "openai": + return ( + request.app.state.config.MODEL + if request.app.state.config.MODEL + else "dall-e-2" + ) + elif request.app.state.config.ENGINE == "comfyui": + return request.app.state.config.MODEL if request.app.state.config.MODEL else "" + elif ( + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" + ): try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": get_automatic1111_api_auth()}, ) options = r.json() return options["sd_model_checkpoint"] except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -225,23 +242,25 @@ class ImageConfigForm(BaseModel): IMAGE_STEPS: int -@app.get("/image/config") -async def get_image_config(user=Depends(get_admin_user)): +@router.get("/image/config") +async def get_image_config(request: Request, user=Depends(get_admin_user)): return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.post("/image/config/update") -async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): +@router.post("/image/config/update") +async def update_image_config( + request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) +): - set_image_model(form_data.MODEL) + set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" if re.match(pattern, form_data.IMAGE_SIZE): - app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE + request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( status_code=400, @@ -249,7 +268,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) if form_data.IMAGE_STEPS >= 0: - app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS + request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS else: raise HTTPException( status_code=400, @@ -257,29 +276,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.get("/models") -def get_models(user=Depends(get_verified_user)): +@router.get("/models") +def get_models(request: Request, user=Depends(get_verified_user)): try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": # TODO - get models from comfyui - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) info = r.json() - workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW) model_node_id = None - for node in app.state.config.COMFYUI_WORKFLOW_NODES: + for node in request.app.state.config.COMFYUI_WORKFLOW_NODES: if node["type"] == "model": if node["node_ids"]: model_node_id = node["node_ids"][0] @@ -315,10 +336,11 @@ def get_models(user=Depends(get_verified_user)): ) ) elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" ): r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, ) models = r.json() @@ -329,7 +351,7 @@ def get_models(user=Depends(get_verified_user)): ) ) except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -401,18 +423,21 @@ def save_url_image(url): return None -@app.post("/generations") +@router.post("/generations") async def image_generations( + request: Request, form_data: GenerateImageForm, user=Depends(get_verified_user), ): - width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) + width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) r = None try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: @@ -423,14 +448,16 @@ async def image_generations( data = { "model": ( - app.state.config.MODEL - if app.state.config.MODEL != "" + request.app.state.config.MODEL + if request.app.state.config.MODEL != "" else "dall-e-2" ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size if form_data.size else app.state.config.IMAGE_SIZE + form_data.size + if form_data.size + else request.app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } @@ -438,7 +465,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -458,7 +485,7 @@ async def image_generations( return images - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -466,8 +493,8 @@ async def image_generations( "n": form_data.n, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -476,18 +503,18 @@ async def image_generations( **{ "workflow": ComfyUIWorkflow( **{ - "workflow": app.state.config.COMFYUI_WORKFLOW, - "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + "workflow": request.app.state.config.COMFYUI_WORKFLOW, + "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, } ), **data, } ) res = await comfyui_generate_image( - app.state.config.MODEL, + request.app.state.config.MODEL, form_data, user.id, - app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -504,7 +531,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" ): if form_data.model: set_image_model(form_data.model) @@ -516,25 +544,25 @@ async def image_generations( "height": height, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.AUTOMATIC1111_CFG_SCALE: - data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE + if request.app.state.config.AUTOMATIC1111_CFG_SCALE: + data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE - if app.state.config.AUTOMATIC1111_SAMPLER: - data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER + if request.app.state.config.AUTOMATIC1111_SAMPLER: + data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER - if app.state.config.AUTOMATIC1111_SCHEDULER: - data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER + if request.app.state.config.AUTOMATIC1111_SCHEDULER: + data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, headers={"authorization": get_automatic1111_api_auth()}, ) diff --git a/backend/open_webui/routers/webui.py b/backend/open_webui/routers/webui.py deleted file mode 100644 index 058d4c365..000000000 --- a/backend/open_webui/routers/webui.py +++ /dev/null @@ -1,94 +0,0 @@ -import inspect -import json -import logging -import time -from typing import AsyncGenerator, Generator, Iterator - -from open_webui.socket.main import get_event_call, get_event_emitter -from open_webui.models.functions import Functions -from open_webui.models.models import Models -from open_webui.routers import ( - auths, - chats, - folders, - configs, - groups, - files, - functions, - memories, - models, - knowledge, - prompts, - evaluations, - tools, - users, - utils, -) -from open_webui.utils.plugin import load_function_module_by_id -from open_webui.config import ( - ADMIN_EMAIL, - CORS_ALLOW_ORIGIN, - DEFAULT_MODELS, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_USER_ROLE, - MODEL_ORDER_LIST, - ENABLE_COMMUNITY_SHARING, - ENABLE_LOGIN_FORM, - ENABLE_MESSAGE_RATING, - ENABLE_SIGNUP, - ENABLE_API_KEY, - ENABLE_EVALUATION_ARENA_MODELS, - EVALUATION_ARENA_MODELS, - DEFAULT_ARENA_MODEL, - JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MANAGEMENT, - OAUTH_ROLES_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, - SHOW_ADMIN_DETAILS, - USER_PERMISSIONS, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_BANNERS, - ENABLE_LDAP, - LDAP_SERVER_LABEL, - LDAP_SERVER_HOST, - LDAP_SERVER_PORT, - LDAP_ATTRIBUTE_FOR_USERNAME, - LDAP_SEARCH_FILTERS, - LDAP_SEARCH_BASE, - LDAP_APP_DN, - LDAP_APP_PASSWORD, - LDAP_USE_TLS, - LDAP_CA_CERT_FILE, - LDAP_CIPHERS, - AppConfig, -) -from open_webui.env import ( - ENV, - SRC_LOG_LEVELS, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, -) -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from open_webui.utils.misc import ( - openai_chat_chunk_message_template, - openai_chat_completion_message_template, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - - -from open_webui.utils.tools import get_tools - - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MAIN"])