From 95057d2368b0edb79a3aeed7c5641aeae519a312 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 21 Aug 2024 00:35:42 +0200 Subject: [PATCH] refac: image gen --- backend/apps/images/main.py | 367 ++++++++---------- backend/apps/images/utils/comfyui.py | 110 ------ backend/config.py | 119 +++++- .../components/admin/Settings/Images.svelte | 78 ++-- 4 files changed, 314 insertions(+), 360 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index d5bf45298..4db887f61 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -1,5 +1,3 @@ -import re -import requests from fastapi import ( FastAPI, Request, @@ -7,14 +5,6 @@ from fastapi import ( HTTPException, ) from fastapi.middleware.cors import CORSMiddleware - -from constants import ERROR_MESSAGES -from utils.utils import ( - get_verified_user, - get_admin_user, -) - -from apps.images.utils.comfyui import ComfyUIGenerateImageForm, comfyui_generate_image from typing import Optional from pydantic import BaseModel from pathlib import Path @@ -23,7 +13,21 @@ import uuid import base64 import json import logging +import re +import requests +from utils.utils import ( + get_verified_user, + get_admin_user, +) + +from apps.images.utils.comfyui import ( + ComfyUIWorkflow, + ComfyUIGenerateImageForm, + comfyui_generate_image, +) + +from constants import ERROR_MESSAGES from config import ( SRC_LOG_LEVELS, CACHE_DIR, @@ -76,6 +80,89 @@ app.state.config.IMAGE_SIZE = IMAGE_SIZE app.state.config.IMAGE_STEPS = IMAGE_STEPS +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "enabled": app.state.config.ENABLED, + "engine": app.state.config.ENGINE, + "openai": { + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + }, + "automatic1111": { + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, + }, + "comfyui": { + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + }, + } + + +class OpenAIConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + + +class Automatic1111ConfigForm(BaseModel): + AUTOMATIC1111_BASE_URL: str + AUTOMATIC1111_API_AUTH: str + + +class ComfyUIConfigForm(BaseModel): + COMFYUI_BASE_URL: str + COMFYUI_WORKFLOW: str + COMFYUI_WORKFLOW_NODES: list[dict] + + +class ConfigForm(BaseModel): + enabled: bool + engine: str + openai: OpenAIConfigForm + automatic1111: Automatic1111ConfigForm + comfyui: ComfyUIConfigForm + + +@app.post("/config/update") +async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): + app.state.config.ENGINE = form_data.engine + app.state.config.ENABLED = form_data.enabled + + app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL + app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + + app.state.config.AUTOMATIC1111_BASE_URL = ( + form_data.automatic1111.AUTOMATIC1111_BASE_URL + ) + app.state.config.AUTOMATIC1111_API_AUTH = ( + form_data.automatic1111.AUTOMATIC1111_API_AUTH + ) + + app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL + app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + + return { + "enabled": app.state.config.ENABLED, + "engine": app.state.config.ENGINE, + "openai": { + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + }, + "automatic1111": { + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, + }, + "comfyui": { + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + }, + } + + def get_automatic1111_api_auth(): if app.state.config.AUTOMATIC1111_API_AUTH is None: return "" @@ -86,166 +173,85 @@ def get_automatic1111_api_auth(): return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config") -async def get_config(request: Request, user=Depends(get_admin_user)): - return { - "engine": app.state.config.ENGINE, - "enabled": app.state.config.ENABLED, - } +def set_image_model(model: str): + app.state.config.MODEL = model + if app.state.config.ENGINE in ["", "automatic1111"]: + api_auth = get_automatic1111_api_auth() + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": api_auth}, + ) + options = r.json() + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + json=options, + headers={"authorization": api_auth}, + ) + return app.state.config.MODEL -class ConfigUpdateForm(BaseModel): - engine: str - enabled: bool - - -@app.post("/config/update") -async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled - return { - "engine": app.state.config.ENGINE, - "enabled": app.state.config.ENABLED, - } - - -class EngineUrlUpdateForm(BaseModel): - AUTOMATIC1111_BASE_URL: Optional[str] = None - AUTOMATIC1111_API_AUTH: Optional[str] = None - COMFYUI_BASE_URL: Optional[str] = None - - -@app.get("/url") -async def get_engine_url(user=Depends(get_admin_user)): - return { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - } - - -@app.post("/url/update") -async def update_engine_url( - form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) -): - if form_data.AUTOMATIC1111_BASE_URL is None: - app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL - else: - url = form_data.AUTOMATIC1111_BASE_URL.strip("/") +def get_image_model(): + if app.state.config.ENGINE == "openai": + return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" + elif app.state.config.ENGINE == "comfyui": + return app.state.config.MODEL if app.state.config.MODEL else "" + elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": try: - r = requests.head(url) - r.raise_for_status() - app.state.config.AUTOMATIC1111_BASE_URL = url + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth()}, + ) + options = r.json() + return options["sd_model_checkpoint"] except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - - if form_data.COMFYUI_BASE_URL is None: - app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL - else: - url = form_data.COMFYUI_BASE_URL.strip("/") - - try: - r = requests.head(url) - r.raise_for_status() - app.state.config.COMFYUI_BASE_URL = url - except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - - if form_data.AUTOMATIC1111_API_AUTH is None: - app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH - else: - app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH - - return { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "status": True, - } + app.state.config.ENABLED = False + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) -class OpenAIConfigUpdateForm(BaseModel): - url: str - key: str - - -@app.get("/openai/config") -async def get_openai_config(user=Depends(get_admin_user)): - return { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - } - - -@app.post("/openai/config/update") -async def update_openai_config( - form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) -): - if form_data.key == "": - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - - app.state.config.OPENAI_API_BASE_URL = form_data.url - app.state.config.OPENAI_API_KEY = form_data.key - - return { - "status": True, - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - } - - -class ImageSizeUpdateForm(BaseModel): +class ImageConfigForm(BaseModel): + model: str size: str + steps: int -@app.get("/size") -async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE} +@app.get("/image/config") +async def get_image_config(user=Depends(get_admin_user)): + return { + "MODEL": app.state.config.MODEL, + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + } -@app.post("/size/update") -async def update_image_size( - form_data: ImageSizeUpdateForm, user=Depends(get_admin_user) -): - pattern = r"^\d+x\d+$" # Regular expression pattern +@app.post("/image/config/update") +async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): + app.state.config.MODEL = form_data.model + + pattern = r"^\d+x\d+$" if re.match(pattern, form_data.size): app.state.config.IMAGE_SIZE = form_data.size - return { - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "status": True, - } else: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."), ) - -class ImageStepsUpdateForm(BaseModel): - steps: int - - -@app.get("/steps") -async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS} - - -@app.post("/steps/update") -async def update_image_size( - form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) -): if form_data.steps >= 0: app.state.config.IMAGE_STEPS = form_data.steps - return { - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, - "status": True, - } else: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."), ) + return { + "MODEL": app.state.config.MODEL, + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + } + @app.get("/models") def get_models(user=Depends(get_verified_user)): @@ -256,7 +262,7 @@ def get_models(user=Depends(get_verified_user)): {"id": "dall-e-3", "name": "DALLĀ·E 3"}, ] elif app.state.config.ENGINE == "comfyui": - + # TODO - get models from comfyui r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() @@ -266,8 +272,9 @@ def get_models(user=Depends(get_verified_user)): info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], ) ) - - else: + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): r = requests.get( url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, @@ -284,69 +291,11 @@ def get_models(user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) -@app.get("/models/default") -async def get_default_model(user=Depends(get_admin_user)): - try: - if app.state.config.ENGINE == "openai": - return { - "model": ( - app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - ) - } - elif app.state.config.ENGINE == "comfyui": - return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} - else: - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, - ) - options = r.json() - return {"model": options["sd_model_checkpoint"]} - except Exception as e: - app.state.config.ENABLED = False - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) - - -class UpdateModelForm(BaseModel): - model: str - - -def set_model_handler(model: str): - if app.state.config.ENGINE in ["openai", "comfyui"]: - app.state.config.MODEL = model - return app.state.config.MODEL - else: - api_auth = get_automatic1111_api_auth() - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": api_auth}, - ) - options = r.json() - - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - json=options, - headers={"authorization": api_auth}, - ) - - return options - - -@app.post("/models/default/update") -def update_default_model( - form_data: UpdateModelForm, - user=Depends(get_verified_user), -): - return set_model_handler(form_data.model) - - class GenerateImageForm(BaseModel): model: Optional[str] = None prompt: str - n: int = 1 size: Optional[str] = None + n: int = 1 negative_prompt: Optional[str] = None @@ -497,9 +446,11 @@ async def image_generations( log.debug(f"images: {images}") return images - else: + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): if form_data.model: - set_model_handler(form_data.model) + set_image_model(form_data.model) data = { "prompt": form_data.prompt, @@ -521,7 +472,6 @@ async def image_generations( ) res = r.json() - log.debug(f"res: {res}") images = [] @@ -538,7 +488,6 @@ async def image_generations( except Exception as e: error = e - if r != None: data = r.json() if "error" in data: diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index d61a8f553..9117d7abf 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -15,116 +15,6 @@ from pydantic import BaseModel from typing import Optional -COMFYUI_DEFAULT_WORKFLOW = """ -{ - "3": { - "inputs": { - "seed": 0, - "steps": 20, - "cfg": 8, - "sampler_name": "euler", - "scheduler": "normal", - "denoise": 1, - "model": [ - "4", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - }, - "4": { - "inputs": { - "ckpt_name": "model.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "5": { - "inputs": { - "width": 512, - "height": 512, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, - "6": { - "inputs": { - "text": "Prompt", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "7": { - "inputs": { - "text": "", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "8": { - "inputs": { - "samples": [ - "3", - 0 - ], - "vae": [ - "4", - 2 - ] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - } -} -""" - def queue_prompt(prompt, client_id, base_url): log.info("queue_prompt") diff --git a/backend/config.py b/backend/config.py index 276f1f636..0ffacca1b 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1342,10 +1342,127 @@ COMFYUI_BASE_URL = PersistentConfig( os.getenv("COMFYUI_BASE_URL", ""), ) +COMFYUI_DEFAULT_WORKFLOW = """ +{ + "3": { + "inputs": { + "seed": 0, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "model.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "Prompt", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} +""" + + COMFYUI_WORKFLOW = PersistentConfig( "COMFYUI_WORKFLOW", "image_generation.comfyui.workflow", - os.getenv("COMFYUI_WORKFLOW", ""), + os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW), +) + +COMFYUI_WORKFLOW_NODES = PersistentConfig( + "COMFYUI_WORKFLOW", + "image_generation.comfyui.nodes", + [], ) IMAGES_OPENAI_API_BASE_URL = PersistentConfig( diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index 9838792f2..12451ef7d 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -20,13 +20,14 @@ } from '$lib/apis/images'; import { getBackendConfig } from '$lib/apis'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; + import Switch from '$lib/components/common/Switch.svelte'; const dispatch = createEventDispatcher(); const i18n = getContext('i18n'); let loading = false; - let imageGenerationEngine = ''; + let imageGenerationEngine = 'openai'; let enableImageGeneration = false; let AUTOMATIC1111_BASE_URL = ''; @@ -128,7 +129,7 @@ }); if (res) { - imageGenerationEngine = res.engine; + imageGenerationEngine = res.engine ?? 'automatic1111'; enableImageGeneration = res.enabled; } const URLS = await getImageGenerationEngineUrls(localStorage.token); @@ -180,6 +181,38 @@
{$i18n.t('Image Settings')}
+
+
+
+ {$i18n.t('Image Generation (Experimental)')} +
+ +
+ { + const enabled = e.detail; + + if (enabled) { + if (imageGenerationEngine === 'automatic1111' && AUTOMATIC1111_BASE_URL === '') { + toast.error($i18n.t('AUTOMATIC1111 Base URL is required.')); + enableImageGeneration = false; + } else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') { + toast.error($i18n.t('ComfyUI Base URL is required.')); + enableImageGeneration = false; + } else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') { + toast.error($i18n.t('OpenAI API Key is required.')); + enableImageGeneration = false; + } + } + + updateImageGeneration(); + }} + /> +
+
+
+
{$i18n.t('Image Generation Engine')}
@@ -191,51 +224,16 @@ await updateImageGeneration(); }} > - + - +
- -
-
-
- {$i18n.t('Image Generation (Experimental)')} -
- - -
-

- {#if imageGenerationEngine === ''} + {#if imageGenerationEngine === 'automatic1111'}
{$i18n.t('AUTOMATIC1111 Base URL')}