open-webui/backend/open_webui/routers/images.py
Timothy Jaeryang Baek 7cf4c9c89c
Some checks are pending
Deploy to HuggingFace Spaces / check-secret (push) Waiting to run
Deploy to HuggingFace Spaces / deploy (push) Blocked by required conditions
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / merge-main-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-cuda-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-ollama-images (push) Blocked by required conditions
Python CI / Format Backend (3.11) (push) Waiting to run
Frontend Build / Format & Build Frontend (push) Waiting to run
Frontend Build / Frontend Unit Tests (push) Waiting to run
Integration Test / Run Cypress Integration Tests (push) Waiting to run
Integration Test / Run Migration Tests (push) Waiting to run
refac: comfyui
2025-01-16 11:17:37 -08:00

622 lines
22 KiB
Python

import asyncio
import base64
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
# Module logger; its level is taken from the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
# Directory (under CACHE_DIR) that receives generated image files and their
# JSON sidecar files. It is created at import time, parents included, if it
# does not exist yet.
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Router on which the endpoints declared below are registered.
router = APIRouter()
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the current image-generation settings, grouped per engine.

    The endpoint is guarded by the ``get_admin_user`` dependency. The response
    carries the global switches (``enabled``, ``engine``,
    ``prompt_generation``) plus one nested mapping for each of ``openai``,
    ``automatic1111`` and ``comfyui``.
    """
    cfg = request.app.state.config

    settings = {
        "enabled": cfg.ENABLE_IMAGE_GENERATION,
        "engine": cfg.IMAGE_GENERATION_ENGINE,
        "prompt_generation": cfg.ENABLE_IMAGE_PROMPT_GENERATION,
    }

    # The OpenAI keys are exposed without the "IMAGES_" prefix used in config.
    settings["openai"] = {
        "OPENAI_API_BASE_URL": cfg.IMAGES_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": cfg.IMAGES_OPENAI_API_KEY,
    }

    # For the remaining engines the response key equals the config attribute.
    settings["automatic1111"] = {
        name: getattr(cfg, name)
        for name in (
            "AUTOMATIC1111_BASE_URL",
            "AUTOMATIC1111_API_AUTH",
            "AUTOMATIC1111_CFG_SCALE",
            "AUTOMATIC1111_SAMPLER",
            "AUTOMATIC1111_SCHEDULER",
        )
    }
    settings["comfyui"] = {
        name: getattr(cfg, name)
        for name in (
            "COMFYUI_BASE_URL",
            "COMFYUI_API_KEY",
            "COMFYUI_WORKFLOW",
            "COMFYUI_WORKFLOW_NODES",
        )
    }
    return settings
# "openai" section of the configuration payload accepted by update_config.
class OpenAIConfigForm(BaseModel):
    # Base URL of the OpenAI-compatible API; "/images/generations" is appended
    # to it when an image is requested.
    OPENAI_API_BASE_URL: str
    # API key, sent as a Bearer token on generation requests.
    OPENAI_API_KEY: str
# "automatic1111" section of the configuration payload accepted by update_config.
class Automatic1111ConfigForm(BaseModel):
    # Base URL of the AUTOMATIC1111 server; "/sdapi/v1/..." paths are appended.
    AUTOMATIC1111_BASE_URL: str
    # Credential string; get_automatic1111_api_auth base64-encodes it into a
    # "Basic ..." authorization header value.
    AUTOMATIC1111_API_AUTH: str
    # May be None; update_config stores float(value) for truthy values and
    # None otherwise.
    AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
    # May be None; update_config stores falsy values as None.
    AUTOMATIC1111_SAMPLER: Optional[str]
    # May be None; update_config stores falsy values as None.
    AUTOMATIC1111_SCHEDULER: Optional[str]
# "comfyui" section of the configuration payload accepted by update_config.
class ComfyUIConfigForm(BaseModel):
    # Base URL of the ComfyUI server; update_config strips leading/trailing "/".
    COMFYUI_BASE_URL: str
    # API key; sent as a Bearer token by get_models and image_generations.
    COMFYUI_API_KEY: str
    # Workflow as JSON text (get_models parses it with json.loads).
    COMFYUI_WORKFLOW: str
    # Node descriptors; get_models reads each entry's "type" and "node_ids" keys.
    COMFYUI_WORKFLOW_NODES: list[dict]
# Request body of POST /config/update: global switches plus one nested form
# per supported engine.
class ConfigForm(BaseModel):
    # Stored as ENABLE_IMAGE_GENERATION.
    enabled: bool
    # Stored as IMAGE_GENERATION_ENGINE; this module branches on the values
    # "openai", "comfyui", "automatic1111" and "".
    engine: str
    # Stored as ENABLE_IMAGE_PROMPT_GENERATION.
    prompt_generation: bool
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
@router.post("/config/update")
async def update_config(
    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
    """Store the submitted image-generation settings and return the result.

    The endpoint is guarded by the ``get_admin_user`` dependency.

    Args:
        request: Incoming request; settings live on ``request.app.state.config``.
        form_data: New settings for the global switches and every engine.
        user: Resolved admin user (unused beyond the access check).

    Returns:
        The stored settings, in the same shape as ``GET /config``.
    """
    cfg = request.app.state.config

    # Global switches.
    cfg.IMAGE_GENERATION_ENGINE = form_data.engine
    cfg.ENABLE_IMAGE_GENERATION = form_data.enabled
    cfg.ENABLE_IMAGE_PROMPT_GENERATION = form_data.prompt_generation

    # OpenAI-compatible backend.
    cfg.IMAGES_OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
    cfg.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

    # AUTOMATIC1111 backend. Falsy optional values (None, "", 0) are stored as
    # None so the generation request omits them.
    a1111 = form_data.automatic1111
    cfg.AUTOMATIC1111_BASE_URL = a1111.AUTOMATIC1111_BASE_URL
    cfg.AUTOMATIC1111_API_AUTH = a1111.AUTOMATIC1111_API_AUTH
    cfg.AUTOMATIC1111_CFG_SCALE = (
        float(a1111.AUTOMATIC1111_CFG_SCALE)
        if a1111.AUTOMATIC1111_CFG_SCALE
        else None
    )
    cfg.AUTOMATIC1111_SAMPLER = a1111.AUTOMATIC1111_SAMPLER or None
    cfg.AUTOMATIC1111_SCHEDULER = a1111.AUTOMATIC1111_SCHEDULER or None

    # ComfyUI backend.
    comfyui = form_data.comfyui
    cfg.COMFYUI_BASE_URL = comfyui.COMFYUI_BASE_URL.strip("/")
    # The API key is part of the form and of the response; persist it so a
    # submitted key is not silently discarded.
    cfg.COMFYUI_API_KEY = comfyui.COMFYUI_API_KEY
    cfg.COMFYUI_WORKFLOW = comfyui.COMFYUI_WORKFLOW
    cfg.COMFYUI_WORKFLOW_NODES = comfyui.COMFYUI_WORKFLOW_NODES

    return {
        "enabled": cfg.ENABLE_IMAGE_GENERATION,
        "engine": cfg.IMAGE_GENERATION_ENGINE,
        "prompt_generation": cfg.ENABLE_IMAGE_PROMPT_GENERATION,
        "openai": {
            "OPENAI_API_BASE_URL": cfg.IMAGES_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": cfg.IMAGES_OPENAI_API_KEY,
        },
        "automatic1111": {
            "AUTOMATIC1111_BASE_URL": cfg.AUTOMATIC1111_BASE_URL,
            "AUTOMATIC1111_API_AUTH": cfg.AUTOMATIC1111_API_AUTH,
            "AUTOMATIC1111_CFG_SCALE": cfg.AUTOMATIC1111_CFG_SCALE,
            "AUTOMATIC1111_SAMPLER": cfg.AUTOMATIC1111_SAMPLER,
            "AUTOMATIC1111_SCHEDULER": cfg.AUTOMATIC1111_SCHEDULER,
        },
        "comfyui": {
            "COMFYUI_BASE_URL": cfg.COMFYUI_BASE_URL,
            "COMFYUI_API_KEY": cfg.COMFYUI_API_KEY,
            "COMFYUI_WORKFLOW": cfg.COMFYUI_WORKFLOW,
            "COMFYUI_WORKFLOW_NODES": cfg.COMFYUI_WORKFLOW_NODES,
        },
    }
def get_automatic1111_api_auth(request: Request):
    """Build the ``authorization`` header value used for AUTOMATIC1111 calls.

    Returns an empty string when ``AUTOMATIC1111_API_AUTH`` is ``None``.
    Otherwise returns ``"Basic "`` followed by the base64 encoding of the
    UTF-8 bytes of the configured string.
    """
    credentials = request.app.state.config.AUTOMATIC1111_API_AUTH
    if credentials is None:
        return ""

    token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {token}"
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
    """Check that the configured backend of the active engine is reachable.

    For ``automatic1111`` the ``/sdapi/v1/options`` endpoint is probed, for
    ``comfyui`` the ``/object_info`` endpoint. Any other engine is accepted
    without a probe.

    Returns:
        ``True`` when the probe succeeds or no probe is needed.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.INVALID_URL`` when the probe
            fails; image generation is switched off in that case.
    """
    cfg = request.app.state.config
    engine = cfg.IMAGE_GENERATION_ENGINE

    if engine not in ("automatic1111", "comfyui"):
        return True

    try:
        if engine == "automatic1111":
            url = f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
            headers = {"authorization": get_automatic1111_api_auth(request)}
        else:
            url = f"{cfg.COMFYUI_BASE_URL}/object_info"
            # Send the API key when one is configured, as the other ComfyUI
            # calls in this module do; otherwise a key-protected server would
            # always fail verification.
            headers = (
                {"Authorization": f"Bearer {cfg.COMFYUI_API_KEY}"}
                if cfg.COMFYUI_API_KEY
                else None
            )

        r = requests.get(url=url, headers=headers)
        r.raise_for_status()
        return True
    except Exception:
        # An unreachable backend disables image generation altogether.
        cfg.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
def set_image_model(request: Request, model: str):
    """Store ``model`` as the active image model and sync AUTOMATIC1111.

    The model name is always written to ``IMAGE_GENERATION_MODEL``. When the
    engine is ``""`` or ``"automatic1111"``, the server's options are fetched
    and, if its ``sd_model_checkpoint`` differs, posted back with the new
    checkpoint.

    Returns:
        The value of ``IMAGE_GENERATION_MODEL`` after the update.
    """
    log.info(f"Setting image model to {model}")

    cfg = request.app.state.config
    cfg.IMAGE_GENERATION_MODEL = model

    # Only AUTOMATIC1111 (also the engine when none is selected) needs a
    # remote update.
    if cfg.IMAGE_GENERATION_ENGINE not in ["", "automatic1111"]:
        return cfg.IMAGE_GENERATION_MODEL

    api_auth = get_automatic1111_api_auth(request)
    response = requests.get(
        url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
        headers={"authorization": api_auth},
    )
    remote_options = response.json()

    if remote_options["sd_model_checkpoint"] != model:
        remote_options["sd_model_checkpoint"] = model
        requests.post(
            url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            json=remote_options,
            headers={"authorization": api_auth},
        )

    return cfg.IMAGE_GENERATION_MODEL
def get_image_model(request):
    """Return the image model that is currently in effect for the engine.

    * ``openai``: the configured model, or ``"dall-e-2"`` when it is falsy.
    * ``comfyui``: the configured model, or ``""`` when it is falsy.
    * ``automatic1111`` / ``""``: the ``sd_model_checkpoint`` reported by the
      server's ``/sdapi/v1/options`` endpoint.

    Any other engine falls through and yields ``None``.

    Raises:
        HTTPException: 400 when the AUTOMATIC1111 lookup fails; image
            generation is switched off in that case.
    """
    cfg = request.app.state.config
    engine = cfg.IMAGE_GENERATION_ENGINE

    if engine == "openai":
        return cfg.IMAGE_GENERATION_MODEL or "dall-e-2"

    if engine == "comfyui":
        return cfg.IMAGE_GENERATION_MODEL or ""

    if engine in ("automatic1111", ""):
        try:
            response = requests.get(
                url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            return response.json()["sd_model_checkpoint"]
        except Exception as e:
            cfg.ENABLE_IMAGE_GENERATION = False
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
# Request body of POST /image/config/update.
class ImageConfigForm(BaseModel):
    # Model name handed to set_image_model.
    MODEL: str
    # Expected as "<width>x<height>" in digits (e.g. "512x512"); checked by
    # update_image_config.
    IMAGE_SIZE: str
    # Must be >= 0; checked by update_image_config.
    IMAGE_STEPS: int
@router.get("/image/config")
async def get_image_config(request: Request, user=Depends(get_admin_user)):
    """Return the stored model, image size and step count.

    The endpoint is guarded by the ``get_admin_user`` dependency.
    """
    cfg = request.app.state.config
    # (response key, config attribute) pairs, in response order.
    exposed = (
        ("MODEL", "IMAGE_GENERATION_MODEL"),
        ("IMAGE_SIZE", "IMAGE_SIZE"),
        ("IMAGE_STEPS", "IMAGE_STEPS"),
    )
    return {key: getattr(cfg, attribute) for key, attribute in exposed}
@router.post("/image/config/update")
async def update_image_config(
    request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
    """Validate and store the model, image size and step count.

    All fields are validated before anything is written, so a request that is
    rejected with 400 leaves the stored model, size and steps unchanged.

    Args:
        request: Incoming request; settings live on ``request.app.state.config``.
        form_data: New model name, ``"<width>x<height>"`` size and step count.
        user: Resolved admin user (unused beyond the access check).

    Returns:
        The stored ``MODEL``, ``IMAGE_SIZE`` and ``IMAGE_STEPS``.

    Raises:
        HTTPException: 400 when ``IMAGE_SIZE`` is not ``<digits>x<digits>`` or
            ``IMAGE_STEPS`` is negative.
    """
    # fullmatch rejects a trailing newline, which `match` with "$" lets through.
    if not re.fullmatch(r"\d+x\d+", form_data.IMAGE_SIZE):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    if form_data.IMAGE_STEPS < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    cfg = request.app.state.config
    set_image_model(request, form_data.MODEL)
    cfg.IMAGE_SIZE = form_data.IMAGE_SIZE
    cfg.IMAGE_STEPS = form_data.IMAGE_STEPS

    return {
        "MODEL": cfg.IMAGE_GENERATION_MODEL,
        "IMAGE_SIZE": cfg.IMAGE_SIZE,
        "IMAGE_STEPS": cfg.IMAGE_STEPS,
    }
@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
    """List the models selectable for the active image-generation engine.

    * ``openai``: a fixed list (DALL·E 2 and DALL·E 3).
    * ``comfyui``: the choices of the first ``*_name`` input of the workflow's
      model node, as reported by ``/object_info``; when the workflow nodes
      name no model node, the ``ckpt_name`` choices of
      ``CheckpointLoaderSimple`` are used instead.
    * ``automatic1111`` / ``""``: the entries of ``/sdapi/v1/sd-models``.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts, or ``None`` when no
        branch produces a list.

    Raises:
        HTTPException: 400 on any failure; image generation is switched off
            in that case.
    """
    try:
        cfg = request.app.state.config
        engine = cfg.IMAGE_GENERATION_ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]
        elif engine == "comfyui":
            # TODO - get models from comfyui
            headers = {"Authorization": f"Bearer {cfg.COMFYUI_API_KEY}"}
            r = requests.get(
                url=f"{cfg.COMFYUI_BASE_URL}/object_info",
                headers=headers,
            )
            info = r.json()

            workflow = json.loads(cfg.COMFYUI_WORKFLOW)

            # First "model" node entry that actually lists node ids.
            model_node_id = None
            for node in cfg.COMFYUI_WORKFLOW_NODES:
                if node["type"] == "model":
                    if node["node_ids"]:
                        model_node_id = node["node_ids"][0]
                        break

            if model_node_id:
                class_type = workflow[model_node_id]["class_type"]
                log.debug(f"ComfyUI model node class type: {class_type}")

                required_inputs = info[class_type]["input"]["required"]
                # The model selector is taken to be the first required input
                # whose name contains "_name".
                model_list_key = next(
                    (key for key in required_inputs if "_name" in key), None
                )

                if model_list_key:
                    return [
                        {"id": model, "name": model}
                        for model in required_inputs[model_list_key][0]
                    ]
            else:
                checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                    "ckpt_name"
                ][0]
                return [{"id": model, "name": model} for model in checkpoints]
        elif engine == "automatic1111" or engine == "":
            r = requests.get(
                url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            models = r.json()
            return [
                {"id": model["title"], "name": model["model_name"]}
                for model in models
            ]
    except Exception as e:
        request.app.state.config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
# Request body of POST /generations.
class GenerateImageForm(BaseModel):
    # Optional model override; image_generations reads it only on the
    # AUTOMATIC1111 path.
    model: Optional[str] = None
    # Text prompt forwarded to the selected engine.
    prompt: str
    # Optional size override; image_generations reads it only on the OpenAI
    # path, falling back to the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Number of images to request (sent as "n", or "batch_size" for
    # AUTOMATIC1111).
    n: int = 1
    # Optional negative prompt; forwarded on the ComfyUI and AUTOMATIC1111 paths.
    negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
    """Decode a base64 image and write it into ``IMAGE_CACHE_DIR``.

    Accepts either a bare base64 payload or a data URL of the form
    ``data:<mime>;base64,<payload>``. For a data URL the file extension is
    derived from the MIME type; a bare payload, or a MIME type with no known
    extension, is saved as ``.png``.

    Args:
        b64_str: Base64 payload or data URL.

    Returns:
        The generated file name (``<uuid><extension>``), or ``None`` when
        decoding or writing fails (the error is logged).
    """
    try:
        image_id = str(uuid.uuid4())

        extension = ".png"
        encoded = b64_str
        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # The header looks like "data:image/png;base64"; the "data:" scheme
            # has to be removed or guess_extension cannot recognise the type.
            mime_type = header.split(";")[0].removeprefix("data:")
            extension = mimetypes.guess_extension(mime_type) or ".png"

        img_data = base64.b64decode(encoded)
        image_filename = f"{image_id}{extension}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as f:
            f.write(img_data)
        return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
def save_url_image(url, headers=None):
    """Download an image from ``url`` and write it into ``IMAGE_CACHE_DIR``.

    Args:
        url: Location of the image.
        headers: Optional request headers; only passed on when truthy.

    Returns:
        The generated file name (``<uuid><extension>``), or ``None`` when the
        response is not an image or any step fails (the failure is logged).
    """
    image_id = str(uuid.uuid4())
    try:
        response = requests.get(url, headers=headers) if headers else requests.get(url)
        response.raise_for_status()

        content_type = response.headers["content-type"]
        if content_type.split("/")[0] != "image":
            log.error("Url does not point to an image.")
            return None

        extension = mimetypes.guess_extension(content_type)
        if not extension:
            raise ValueError("Could not determine image type from MIME type")

        image_filename = f"{image_id}{extension}"
        target = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
        with open(target, "wb") as image_file:
            for chunk in response.iter_content(chunk_size=8192):
                image_file.write(chunk)
        return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
@router.post("/generations")
async def image_generations(
    request: Request,
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and cache them locally.

    Depending on ``IMAGE_GENERATION_ENGINE`` the request goes to an
    OpenAI-compatible API, to ComfyUI, or to AUTOMATIC1111 (also used when the
    engine is ``""``). Each resulting image is stored in ``IMAGE_CACHE_DIR``
    next to a ``<file name>.json`` sidecar holding the request parameters.

    Args:
        request: Incoming request; settings live on ``request.app.state.config``.
        form_data: Prompt and optional overrides.
        user: Verified user; forwarded to the backend where applicable.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file name>"}`` dicts
        (``None`` if the engine matches no branch).

    Raises:
        HTTPException: 400 on any failure. When the backend replied with a
            JSON body containing ``error.message``, that message is reported.
    """
    cfg = request.app.state.config
    width, height = tuple(map(int, cfg.IMAGE_SIZE.split("x")))

    r = None
    try:
        if cfg.IMAGE_GENERATION_ENGINE == "openai":
            headers = {}
            headers["Authorization"] = f"Bearer {cfg.IMAGES_OPENAI_API_KEY}"
            headers["Content-Type"] = "application/json"

            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers["X-OpenWebUI-User-Name"] = user.name
                headers["X-OpenWebUI-User-Id"] = user.id
                headers["X-OpenWebUI-User-Email"] = user.email
                headers["X-OpenWebUI-User-Role"] = user.role

            data = {
                "model": (
                    cfg.IMAGE_GENERATION_MODEL
                    if cfg.IMAGE_GENERATION_MODEL != ""
                    else "dall-e-2"
                ),
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": (form_data.size if form_data.size else cfg.IMAGE_SIZE),
                "response_format": "b64_json",
            }

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{cfg.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump(data, f)

            return images

        elif cfg.IMAGE_GENERATION_ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if cfg.IMAGE_STEPS is not None:
                data["steps"] = cfg.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            comfyui_form = ComfyUIGenerateImageForm(
                **{
                    "workflow": ComfyUIWorkflow(
                        **{
                            "workflow": cfg.COMFYUI_WORKFLOW,
                            "nodes": cfg.COMFYUI_WORKFLOW_NODES,
                        }
                    ),
                    **data,
                }
            )
            res = await comfyui_generate_image(
                cfg.IMAGE_GENERATION_MODEL,
                comfyui_form,
                user.id,
                cfg.COMFYUI_BASE_URL,
                cfg.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            # The download headers are the same for every image.
            headers = None
            if cfg.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {cfg.COMFYUI_API_KEY}"}

            images = []
            for image in res["data"]:
                image_filename = save_url_image(image["url"], headers)
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump(comfyui_form.model_dump(exclude_none=True), f)

            log.debug(f"images: {images}")
            return images

        elif (
            cfg.IMAGE_GENERATION_ENGINE == "automatic1111"
            or cfg.IMAGE_GENERATION_ENGINE == ""
        ):
            if form_data.model:
                # set_image_model takes the request as its first argument.
                set_image_model(request, form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if cfg.IMAGE_STEPS is not None:
                data["steps"] = cfg.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if cfg.AUTOMATIC1111_CFG_SCALE:
                data["cfg_scale"] = cfg.AUTOMATIC1111_CFG_SCALE

            if cfg.AUTOMATIC1111_SAMPLER:
                data["sampler_name"] = cfg.AUTOMATIC1111_SAMPLER

            if cfg.AUTOMATIC1111_SCHEDULER:
                data["scheduler"] = cfg.AUTOMATIC1111_SCHEDULER

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth(request)},
            )

            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_filename = save_b64_image(image)
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump({**data, "info": res["info"]}, f)

            return images
    except Exception as e:
        error = e
        if r is not None:
            # Prefer the backend's own error message, but never let a
            # non-JSON or unexpectedly shaped body mask the original failure.
            try:
                body = r.json()
            except ValueError:
                body = None
            backend_error = body.get("error") if isinstance(body, dict) else None
            if isinstance(backend_error, dict) and "message" in backend_error:
                error = backend_error["message"]
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))