diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 12693cf8b..b829b0499 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -18,6 +18,8 @@ from utils.utils import ( get_current_user, get_admin_user, ) + +from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel @@ -105,7 +107,12 @@ async def update_engine_url( app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL else: url = form_data.COMFYUI_BASE_URL.strip("/") - app.state.COMFYUI_BASE_URL = url + + try: + r = requests.head(url) + app.state.COMFYUI_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, @@ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)): try: if app.state.ENGINE == "openai": return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + elif app.state.ENGINE == "comfyui": + return {"model": app.state.MODEL if app.state.MODEL else ""} else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -246,10 +255,12 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE == "openai": app.state.MODEL = model return app.state.MODEL + if app.state.ENGINE == "comfyui": + app.state.MODEL = model + return app.state.MODEL else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -297,12 +308,31 @@ def save_b64_image(b64_str): return None +def save_url_image(url): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + r = requests.get(url) + r.raise_for_status() + + with open(file_path, "wb") as image_file: + image_file.write(r.content) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") 
+ return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, user=Depends(get_current_user), ): + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + r = None try: if app.state.ENGINE == "openai": @@ -340,12 +370,47 @@ def generate_image( return images + elif app.state.ENGINE == "comfyui": + + data = { + "prompt": form_data.prompt, + "width": width, + "height": height, + "n": form_data.n, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + data = ImageGenerationPayload(**data) + + res = comfyui_generate_image( + app.state.MODEL, + data, + user.id, + app.state.COMFYUI_BASE_URL, + ) + print(res) + + images = [] + + for image in res["data"]: + image_id = save_url_image(image["url"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data.model_dump(exclude_none=True), f) + + print(images) + return images else: if form_data.model: set_model_handler(form_data.model) - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) - data = { "prompt": form_data.prompt, "batch_size": form_data.n, diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py new file mode 100644 index 000000000..6a9fef353 --- /dev/null +++ b/backend/apps/images/utils/comfyui.py @@ -0,0 +1,228 @@ +import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse +import random + +from pydantic import BaseModel + +from typing import Optional + +COMFYUI_DEFAULT_PROMPT = """ +{ + "3": { + "inputs": { + "seed": 0, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 
0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "model.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "Prompt", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "Negative Prompt", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} +""" + + +def queue_prompt(prompt, client_id, base_url): + print("queue_prompt") + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode("utf-8") + req = urllib.request.Request(f"{base_url}/prompt", data=data) + return json.loads(urllib.request.urlopen(req).read()) + + +def get_image(filename, subfolder, folder_type, base_url): + print("get_image") + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: + return response.read() + + +def get_image_url(filename, subfolder, folder_type, base_url): + print("get_image") + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + return f"{base_url}/view?{url_values}" + 
+ +def get_history(prompt_id, base_url): +    print("get_history") +    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: +        return json.loads(response.read()) + + +def get_images(ws, prompt, client_id, base_url): +    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] +    output_images = [] +    while True: +        out = ws.recv() +        if isinstance(out, str): +            message = json.loads(out) +            if message["type"] == "executing": +                data = message["data"] +                if data["node"] is None and data["prompt_id"] == prompt_id: +                    break # Execution is done +        else: +            continue # previews are binary data + +    history = get_history(prompt_id, base_url)[prompt_id] +    outputs = history["outputs"] +    for node_id in outputs: +        node_output = outputs[node_id] +        if "images" in node_output: +            for image in node_output["images"]: +                url = get_image_url( +                    image["filename"], image["subfolder"], image["type"], base_url +                ) +                output_images.append({"url": url}) +    return {"data": output_images} + + +class ImageGenerationPayload(BaseModel): +    prompt: str +    negative_prompt: Optional[str] = "" +    steps: Optional[int] = None +    seed: Optional[int] = None +    width: int +    height: int +    n: int = 1 + + +def comfyui_generate_image( +    model: str, payload: ImageGenerationPayload, client_id, base_url +): +    host = base_url.replace("http://", "").replace("https://", "") + +    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) + +    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model +    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n +    comfyui_prompt["5"]["inputs"]["width"] = payload.width +    comfyui_prompt["5"]["inputs"]["height"] = payload.height + +    # set the text prompt for our positive CLIPTextEncode +    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt +    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt + +    if payload.steps: +        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps + +    comfyui_prompt["3"]["inputs"]["seed"] = ( +        payload.seed if payload.seed
else random.randint(0, 18446744073709551614) + ) + + try: + ws = websocket.WebSocket() + ws.connect(f"ws://{host}/ws?clientId={client_id}") + print("WebSocket connection established.") + except Exception as e: + print(f"Failed to connect to WebSocket server: {e}") + return None + + try: + images = get_images(ws, comfyui_prompt, client_id, base_url) + except Exception as e: + print(f"Error while receiving images: {e}") + images = None + + ws.close() + + return images diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index ee481402e..7282c184a 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -323,6 +323,7 @@ class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" bind:value={selectedModel} placeholder={$i18n.t('Select a model')} + required > {#if !selectedModel} diff --git a/src/lib/components/common/ImagePreview.svelte b/src/lib/components/common/ImagePreview.svelte index cf69327fa..badabebda 100644 --- a/src/lib/components/common/ImagePreview.svelte +++ b/src/lib/components/common/ImagePreview.svelte @@ -2,6 +2,22 @@ export let show = false; export let src = ''; export let alt = ''; + + const downloadImage = (url, filename) => { + fetch(url) + .then((response) => response.blob()) + .then((blob) => { + const objectUrl = window.URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = objectUrl; + link.download = filename; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + window.URL.revokeObjectURL(objectUrl); + }) + .catch((error) => console.error('Error downloading image:', error)); + }; {#if show} @@ -35,10 +51,7 @@