"""HTTP routes for configuring image generation and generating images.

Three engines are handled, selected by ``config.IMAGE_GENERATION_ENGINE``:
``"openai"``, ``"comfyui"`` and ``"automatic1111"`` (the empty string is
treated like ``"automatic1111"``). Generated images are written to
``IMAGE_CACHE_DIR`` together with a ``<filename>.json`` sidecar holding the
request parameters.
"""

import asyncio
import base64
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
    ComfyUIGenerateImageForm,
    ComfyUIWorkflow,
    comfyui_generate_image,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

router = APIRouter()


def _serialize_config(request: Request) -> dict:
    """Return the image-generation settings from ``request.app.state.config`` as a dict."""
    config = request.app.state.config
    return {
        "enabled": config.ENABLE_IMAGE_GENERATION,
        "engine": config.IMAGE_GENERATION_ENGINE,
        "prompt_generation": config.ENABLE_IMAGE_PROMPT_GENERATION,
        "openai": {
            "OPENAI_API_BASE_URL": config.IMAGES_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": config.IMAGES_OPENAI_API_KEY,
        },
        "automatic1111": {
            "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
            "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
            "AUTOMATIC1111_CFG_SCALE": config.AUTOMATIC1111_CFG_SCALE,
            "AUTOMATIC1111_SAMPLER": config.AUTOMATIC1111_SAMPLER,
            "AUTOMATIC1111_SCHEDULER": config.AUTOMATIC1111_SCHEDULER,
        },
        "comfyui": {
            "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
            "COMFYUI_API_KEY": config.COMFYUI_API_KEY,
            "COMFYUI_WORKFLOW": config.COMFYUI_WORKFLOW,
            "COMFYUI_WORKFLOW_NODES": config.COMFYUI_WORKFLOW_NODES,
        },
    }


@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the current image-generation configuration (admin only)."""
    return _serialize_config(request)


class OpenAIConfigForm(BaseModel):
    """OpenAI engine settings submitted to ``/config/update``."""

    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str


class Automatic1111ConfigForm(BaseModel):
    """AUTOMATIC1111 engine settings submitted to ``/config/update``."""

    AUTOMATIC1111_BASE_URL: str
    AUTOMATIC1111_API_AUTH: str
    AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
    AUTOMATIC1111_SAMPLER: Optional[str]
    AUTOMATIC1111_SCHEDULER: Optional[str]


class ComfyUIConfigForm(BaseModel):
    """ComfyUI engine settings submitted to ``/config/update``."""

    COMFYUI_BASE_URL: str
    COMFYUI_API_KEY: str
    COMFYUI_WORKFLOW: str
    COMFYUI_WORKFLOW_NODES: list[dict]


class ConfigForm(BaseModel):
    """Complete payload accepted by ``/config/update``."""

    enabled: bool
    engine: str
    prompt_generation: bool
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm


@router.post("/config/update")
async def update_config(
    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
    """Store the submitted settings on the app config and return the result (admin only)."""
    config = request.app.state.config

    config.IMAGE_GENERATION_ENGINE = form_data.engine
    config.ENABLE_IMAGE_GENERATION = form_data.enabled
    config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.prompt_generation

    config.IMAGES_OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
    config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

    a1111 = form_data.automatic1111
    config.AUTOMATIC1111_BASE_URL = a1111.AUTOMATIC1111_BASE_URL
    config.AUTOMATIC1111_API_AUTH = a1111.AUTOMATIC1111_API_AUTH
    # Falsy values (None, "", 0) are all stored as None.
    config.AUTOMATIC1111_CFG_SCALE = (
        float(a1111.AUTOMATIC1111_CFG_SCALE)
        if a1111.AUTOMATIC1111_CFG_SCALE
        else None
    )
    config.AUTOMATIC1111_SAMPLER = a1111.AUTOMATIC1111_SAMPLER or None
    config.AUTOMATIC1111_SCHEDULER = a1111.AUTOMATIC1111_SCHEDULER or None

    comfyui = form_data.comfyui
    config.COMFYUI_BASE_URL = comfyui.COMFYUI_BASE_URL.strip("/")
    # The key is part of the form and of the response, so it has to be stored
    # too; previously it was silently dropped.
    config.COMFYUI_API_KEY = comfyui.COMFYUI_API_KEY
    config.COMFYUI_WORKFLOW = comfyui.COMFYUI_WORKFLOW
    config.COMFYUI_WORKFLOW_NODES = comfyui.COMFYUI_WORKFLOW_NODES

    return _serialize_config(request)


def get_automatic1111_api_auth(request: Request):
    """Return the ``Basic ...`` authorization value for AUTOMATIC1111.

    Returns an empty string when ``AUTOMATIC1111_API_AUTH`` is ``None``;
    otherwise the configured string is UTF-8 encoded and base64 encoded.
    """
    api_auth = request.app.state.config.AUTOMATIC1111_API_AUTH
    if api_auth is None:
        return ""

    encoded = base64.b64encode(api_auth.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
    """Check that the configured engine URL answers successfully (admin only).

    Returns ``True`` on success and for engines without a URL check.

    Raises:
        HTTPException: 400 when the request fails; image generation is also
            disabled on the app config in that case.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    if engine == "automatic1111":
        try:
            r = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            r.raise_for_status()
            return True
        except Exception:
            config.ENABLE_IMAGE_GENERATION = False
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
    elif engine == "comfyui":
        # Send the API key when one is configured, as get_models does for the
        # same endpoint.
        headers = None
        if config.COMFYUI_API_KEY:
            headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
        try:
            r = requests.get(
                url=f"{config.COMFYUI_BASE_URL}/object_info", headers=headers
            )
            r.raise_for_status()
            return True
        except Exception:
            config.ENABLE_IMAGE_GENERATION = False
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
    else:
        return True


def set_image_model(request: Request, model: str):
    """Store *model* as the configured image model and return it.

    For the AUTOMATIC1111 engine (or an empty engine name) the remote
    ``sd_model_checkpoint`` option is also updated when it differs.
    """
    log.info(f"Setting image model to {model}")
    config = request.app.state.config
    config.IMAGE_GENERATION_MODEL = model

    if config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
        api_auth = get_automatic1111_api_auth(request)
        options_url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

        r = requests.get(url=options_url, headers={"authorization": api_auth})
        options = r.json()
        if model != options["sd_model_checkpoint"]:
            options["sd_model_checkpoint"] = model
            r = requests.post(
                url=options_url,
                json=options,
                headers={"authorization": api_auth},
            )

    return config.IMAGE_GENERATION_MODEL


def get_image_model(request):
    """Return the image model for the configured engine.

    * ``openai``: the configured model, or ``"dall-e-2"`` when unset.
    * ``comfyui``: the configured model, or ``""`` when unset.
    * ``automatic1111`` / empty: the remote ``sd_model_checkpoint`` option.

    Raises:
        HTTPException: 400 when the AUTOMATIC1111 lookup fails; image
            generation is also disabled on the app config in that case.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    if engine == "openai":
        return config.IMAGE_GENERATION_MODEL or "dall-e-2"
    elif engine == "comfyui":
        return config.IMAGE_GENERATION_MODEL or ""
    elif engine == "automatic1111" or engine == "":
        try:
            r = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            options = r.json()
            return options["sd_model_checkpoint"]
        except Exception as e:
            config.ENABLE_IMAGE_GENERATION = False
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))


class ImageConfigForm(BaseModel):
    """Payload accepted by ``/image/config/update``."""

    MODEL: str
    IMAGE_SIZE: str
    IMAGE_STEPS: int


@router.get("/image/config")
async def get_image_config(request: Request, user=Depends(get_admin_user)):
    """Return the configured model, image size and step count (admin only)."""
    config = request.app.state.config
    return {
        "MODEL": config.IMAGE_GENERATION_MODEL,
        "IMAGE_SIZE": config.IMAGE_SIZE,
        "IMAGE_STEPS": config.IMAGE_STEPS,
    }


@router.post("/image/config/update")
async def update_image_config(
    request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
    """Update model, image size and step count (admin only).

    Raises:
        HTTPException: 400 when ``IMAGE_SIZE`` is not ``<digits>x<digits>`` or
            ``IMAGE_STEPS`` is negative.
    """
    config = request.app.state.config

    # NOTE: the model is applied before size/steps are validated, so it is
    # kept even when one of the checks below rejects the request.
    set_image_model(request, form_data.MODEL)

    pattern = r"^\d+x\d+$"
    if re.match(pattern, form_data.IMAGE_SIZE):
        config.IMAGE_SIZE = form_data.IMAGE_SIZE
    else:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    if form_data.IMAGE_STEPS >= 0:
        config.IMAGE_STEPS = form_data.IMAGE_STEPS
    else:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    return {
        "MODEL": config.IMAGE_GENERATION_MODEL,
        "IMAGE_SIZE": config.IMAGE_SIZE,
        "IMAGE_STEPS": config.IMAGE_STEPS,
    }


@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
    """Return the models available for the configured engine.

    Each entry is a dict with ``id`` and ``name`` keys.

    Raises:
        HTTPException: 400 when anything fails; image generation is also
            disabled on the app config in that case.
    """
    config = request.app.state.config
    try:
        engine = config.IMAGE_GENERATION_ENGINE
        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]
        elif engine == "comfyui":
            # TODO - get models from comfyui
            headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
            r = requests.get(
                url=f"{config.COMFYUI_BASE_URL}/object_info",
                headers=headers,
            )
            info = r.json()

            workflow = json.loads(config.COMFYUI_WORKFLOW)
            model_node_id = None

            # Use the first node id of the first workflow node typed "model".
            for node in config.COMFYUI_WORKFLOW_NODES:
                if node["type"] == "model":
                    if node["node_ids"]:
                        model_node_id = node["node_ids"][0]
                    break

            if model_node_id:
                class_type = workflow[model_node_id]["class_type"]
                log.debug(f"model node class_type: {class_type}")
                required_inputs = info[class_type]["input"]["required"]

                # The first required input whose name contains "_name" is
                # taken to hold the list of selectable models.
                model_list_key = next(
                    (key for key in required_inputs if "_name" in key), None
                )

                if model_list_key:
                    return [
                        {"id": model, "name": model}
                        for model in required_inputs[model_list_key][0]
                    ]
            else:
                return [
                    {"id": model, "name": model}
                    for model in info["CheckpointLoaderSimple"]["input"]["required"][
                        "ckpt_name"
                    ][0]
                ]
        elif engine == "automatic1111" or engine == "":
            r = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            models = r.json()
            return [
                {"id": model["title"], "name": model["model_name"]}
                for model in models
            ]
    except Exception as e:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))


class GenerateImageForm(BaseModel):
    """Payload accepted by ``/generations``."""

    model: Optional[str] = None
    prompt: str
    size: Optional[str] = None
    n: int = 1
    negative_prompt: Optional[str] = None


def save_b64_image(b64_str):
    """Decode a base64 image and write it to ``IMAGE_CACHE_DIR``.

    *b64_str* is either bare base64 (saved as ``.png``) or has a
    comma-separated header in front of the payload, in which case the file
    extension is derived from the MIME type in that header.

    Returns:
        The new file name, or ``None`` when decoding or writing fails (the
        error is logged).
    """
    try:
        image_id = str(uuid.uuid4())

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # For a data URI the header is e.g. "data:image/png;base64"; the
            # "data:" scheme has to go or guess_extension() returns None.
            mime_type = header.split(";")[0].removeprefix("data:")

            img_data = base64.b64decode(encoded)
            # Fall back to the extension used for bare base64 input rather
            # than producing a file name that ends in "None".
            image_format = mimetypes.guess_extension(mime_type) or ".png"

            image_filename = f"{image_id}{image_format}"
            file_path = IMAGE_CACHE_DIR / f"{image_filename}"
            with open(file_path, "wb") as f:
                f.write(img_data)
            return image_filename
        else:
            image_filename = f"{image_id}.png"
            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)

            img_data = base64.b64decode(b64_str)

            # Write the image data to a file
            with open(file_path, "wb") as f:
                f.write(img_data)
            return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None


def save_url_image(url):
    """Download the image at *url* into ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name, or ``None`` when the download fails, the response
        is not an image, or its type cannot be mapped to a file extension
        (the problem is logged).
    """
    image_id = str(uuid.uuid4())
    try:
        r = requests.get(url)
        r.raise_for_status()
        if r.headers["content-type"].split("/")[0] == "image":
            # Drop header parameters such as "; charset=..." so that the bare
            # MIME type is what gets looked up.
            mime_type = r.headers["content-type"].split(";")[0].strip()
            image_format = mimetypes.guess_extension(mime_type)
            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"

            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
            with open(file_path, "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename
        else:
            log.error("Url does not point to an image.")
            return None

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None


def _record_generated_image(images: list, image_filename, metadata: dict) -> None:
    """Append the cache URL for *image_filename* to *images* and write *metadata* as its JSON sidecar."""
    # NOTE: image_filename is None when saving failed; as before, that case is
    # not filtered out here.
    images.append({"url": f"/cache/image/generations/{image_filename}"})
    file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
    with open(file_body_path, "w") as f:
        json.dump(metadata, f)


@router.post("/generations")
async def image_generations(
    request: Request,
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and cache them locally.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts.

    Raises:
        HTTPException: 400 when generation fails. If the engine's response
            body is JSON with an ``error`` entry, that is used as the detail.
    """
    config = request.app.state.config
    width, height = tuple(map(int, config.IMAGE_SIZE.split("x")))

    r = None
    try:
        engine = config.IMAGE_GENERATION_ENGINE
        if engine == "openai":
            headers = {}
            headers["Authorization"] = f"Bearer {config.IMAGES_OPENAI_API_KEY}"
            headers["Content-Type"] = "application/json"

            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers["X-OpenWebUI-User-Name"] = user.name
                headers["X-OpenWebUI-User-Id"] = user.id
                headers["X-OpenWebUI-User-Email"] = user.email
                headers["X-OpenWebUI-User-Role"] = user.role

            data = {
                "model": (
                    config.IMAGE_GENERATION_MODEL
                    if config.IMAGE_GENERATION_MODEL != ""
                    else "dall-e-2"
                ),
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": form_data.size if form_data.size else config.IMAGE_SIZE,
                "response_format": "b64_json",
            }

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                _record_generated_image(images, image_filename, data)

            return images

        elif engine == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            form_data = ComfyUIGenerateImageForm(
                **{
                    "workflow": ComfyUIWorkflow(
                        **{
                            "workflow": config.COMFYUI_WORKFLOW,
                            "nodes": config.COMFYUI_WORKFLOW_NODES,
                        }
                    ),
                    **data,
                }
            )
            res = await comfyui_generate_image(
                config.IMAGE_GENERATION_MODEL,
                form_data,
                user.id,
                config.COMFYUI_BASE_URL,
                config.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            images = []
            for image in res["data"]:
                image_filename = save_url_image(image["url"])
                _record_generated_image(
                    images, image_filename, form_data.model_dump(exclude_none=True)
                )

            log.debug(f"images: {images}")
            return images

        elif engine == "automatic1111" or engine == "":
            if form_data.model:
                # set_image_model() requires the request as first argument;
                # it was missing here, so passing a model raised TypeError.
                set_image_model(request, form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if config.AUTOMATIC1111_CFG_SCALE:
                data["cfg_scale"] = config.AUTOMATIC1111_CFG_SCALE

            if config.AUTOMATIC1111_SAMPLER:
                data["sampler_name"] = config.AUTOMATIC1111_SAMPLER

            if config.AUTOMATIC1111_SCHEDULER:
                data["scheduler"] = config.AUTOMATIC1111_SCHEDULER

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth(request)},
            )

            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_filename = save_b64_image(image)
                _record_generated_image(
                    images, image_filename, {**data, "info": res["info"]}
                )

            return images
    except Exception as e:
        error = e
        if r is not None:
            # The body may not be JSON at all; parsing problems must not
            # replace the original error.
            try:
                data = r.json()
            except ValueError:
                data = None

            if isinstance(data, dict) and "error" in data:
                upstream_error = data["error"]
                # "error" may be an object with a "message" or a plain value.
                if isinstance(upstream_error, dict) and "message" in upstream_error:
                    error = upstream_error["message"]
                else:
                    error = upstream_error
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))