From 862c96fcef2b5c822849eda7b3c929b53477033b Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 23 Mar 2024 15:38:59 -0700 Subject: [PATCH] feat: comfyui support --- backend/apps/images/main.py | 43 ++++++-- backend/config.py | 1 + src/lib/apis/images/index.ts | 10 +- .../components/chat/Settings/Images.svelte | 98 +++++++++++++++---- 4 files changed, 120 insertions(+), 32 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index e14b0f6a7..12693cf8b 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -26,7 +26,7 @@ import uuid import base64 import json -from config import CACHE_DIR, AUTOMATIC1111_BASE_URL +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") @@ -49,6 +49,8 @@ app.state.MODEL = "" app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 @@ -71,32 +73,43 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -class UrlUpdateForm(BaseModel): - url: str +class EngineUrlUpdateForm(BaseModel): + AUTOMATIC1111_BASE_URL: Optional[str] = None + COMFYUI_BASE_URL: Optional[str] = None @app.get("/url") -async def get_automatic1111_url(user=Depends(get_admin_user)): - return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} +async def get_engine_url(user=Depends(get_admin_user)): + return { + "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + } @app.post("/url/update") -async def update_automatic1111_url( - form_data: UrlUpdateForm, user=Depends(get_admin_user) +async def update_engine_url( + form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) ): - if form_data.url == "": + if form_data.AUTOMATIC1111_BASE_URL == None: app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - url = form_data.url.strip("/") + url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) app.state.AUTOMATIC1111_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + if form_data.COMFYUI_BASE_URL == None: + app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + else: + url = form_data.COMFYUI_BASE_URL.strip("/") + app.state.COMFYUI_BASE_URL = url + return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "status": True, } @@ -186,6 +199,18 @@ def get_models(user=Depends(get_current_user)): {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] + elif app.state.ENGINE == "comfyui": + + r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") + info = r.json() + + return list( + map( + lambda model: {"id": model, "name": model}, + info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], + ) + ) + else: r = requests.get( url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" diff --git a/backend/config.py b/backend/config.py index 9236e8a86..67edd3f4f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -376,3 +376,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models" #################################### AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") +COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 1fb004a3c..aadfafd14 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => { return res.OPENAI_API_KEY; }; -export const getAUTOMATIC1111Url = async (token: string = '') => { +export const getImageGenerationEngineUrls = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { @@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => { throw error; } - return res.AUTOMATIC1111_BASE_URL; + return res; }; -export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => { +export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { @@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + ...urls }) }) .then(async (res) => { @@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => throw error; } - return res.AUTOMATIC1111_BASE_URL; + return res; }; export const getImageSize = async (token: string = '') => { diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index 5ba046f19..ee481402e 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -4,14 +4,14 @@ import { createEventDispatcher, onMount, getContext } from 'svelte'; import { config, user } from '$lib/stores'; import { - getAUTOMATIC1111Url, getImageGenerationModels, getDefaultImageGenerationModel, updateDefaultImageGenerationModel, getImageSize, getImageGenerationConfig, updateImageGenerationConfig, - updateAUTOMATIC1111Url, + getImageGenerationEngineUrls, + updateImageGenerationEngineUrls, updateImageSize, getImageSteps, updateImageSteps, @@ -31,6 +31,8 @@ let enableImageGeneration = false; let AUTOMATIC1111_BASE_URL = ''; + let COMFYUI_BASE_URL = ''; + let OPENAI_API_KEY = ''; let selectedModel = ''; @@ -49,24 +51,47 @@ }); }; - const updateAUTOMATIC1111UrlHandler = async () => { - const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch( - (error) => { + const updateUrlHandler = async () => { + if (imageGenerationEngine === 'comfyui') { + const res = await updateImageGenerationEngineUrls(localStorage.token, { + COMFYUI_BASE_URL: COMFYUI_BASE_URL + }).catch((error) => { toast.error(error); + + console.log(error); return null; - } - ); + }); - if (res) { - AUTOMATIC1111_BASE_URL = res; + if (res) { + COMFYUI_BASE_URL = res.COMFYUI_BASE_URL; - await getModels(); + await getModels(); - if (models) { - toast.success($i18n.t('Server connection verified')); + if (models) { + toast.success($i18n.t('Server connection verified')); + } + } else { + ({ COMFYUI_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token)); } } else { - AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + const res = await updateImageGenerationEngineUrls(localStorage.token, { + AUTOMATIC1111_BASE_URL: AUTOMATIC1111_BASE_URL + }).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL; + + await getModels(); + + if (models) { + toast.success($i18n.t('Server connection verified')); + } + } else { + ({ AUTOMATIC1111_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token)); + } } }; const updateImageGeneration = async () => { @@ -101,7 +126,11 @@ imageGenerationEngine = res.engine; enableImageGeneration = res.enabled; } - AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + const URLS = await getImageGenerationEngineUrls(localStorage.token); + + AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL; + COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL; + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); imageSize = await getImageSize(localStorage.token); @@ -154,6 +183,7 @@ }} > + @@ -171,6 +201,9 @@ if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { toast.error($i18n.t('AUTOMATIC1111 Base URL is required.')); enableImageGeneration = false; + } else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') { + toast.error($i18n.t('ComfyUI Base URL is required.')); + enableImageGeneration = false; } else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') { toast.error($i18n.t('OpenAI API Key is required.')); enableImageGeneration = false; @@ -204,12 +237,10 @@ /> + {:else if imageGenerationEngine === 'openai'}
{$i18n.t('OpenAI API Key')}