diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index 12693cf8b..b829b0499 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -18,6 +18,8 @@ from utils.utils import (
get_current_user,
get_admin_user,
)
+
+from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
@@ -105,7 +107,12 @@ async def update_engine_url(
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
- app.state.COMFYUI_BASE_URL = url
+
+ try:
+            r = requests.head(url, timeout=5)
+ app.state.COMFYUI_BASE_URL = url
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@@ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+ elif app.state.ENGINE == "comfyui":
+ return {"model": app.state.MODEL if app.state.MODEL else ""}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@@ -246,10 +255,12 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
-
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
+ if app.state.ENGINE == "comfyui":
+ app.state.MODEL = model
+ return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@@ -297,12 +308,31 @@ def save_b64_image(b64_str):
return None
+def save_url_image(url):
+ image_id = str(uuid.uuid4())
+ file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+ try:
+        r = requests.get(url, timeout=30)
+ r.raise_for_status()
+
+ with open(file_path, "wb") as image_file:
+ image_file.write(r.content)
+
+ return image_id
+ except Exception as e:
+ print(f"Error saving image: {e}")
+ return None
+
+
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
+ width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
r = None
try:
if app.state.ENGINE == "openai":
@@ -340,12 +370,47 @@ def generate_image(
return images
+ elif app.state.ENGINE == "comfyui":
+
+ data = {
+ "prompt": form_data.prompt,
+ "width": width,
+ "height": height,
+ "n": form_data.n,
+ }
+
+        if app.state.IMAGE_STEPS is not None:
+            data["steps"] = app.state.IMAGE_STEPS
+
+        if form_data.negative_prompt is not None:
+            data["negative_prompt"] = form_data.negative_prompt
+
+ data = ImageGenerationPayload(**data)
+
+ res = comfyui_generate_image(
+ app.state.MODEL,
+ data,
+ user.id,
+ app.state.COMFYUI_BASE_URL,
+ )
+ print(res)
+
+ images = []
+
+        for image in res["data"]:
+            image_id = save_url_image(image["url"])
+            if image_id is None:
+                continue  # download failed; skip this image
+            images.append({"url": f"/cache/image/generations/{image_id}.png"})
+            file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+            file_body_path.write_text(json.dumps(data.model_dump(exclude_none=True)))
+
+ print(images)
+ return images
else:
if form_data.model:
set_model_handler(form_data.model)
- width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
-
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py
new file mode 100644
index 000000000..6a9fef353
--- /dev/null
+++ b/backend/apps/images/utils/comfyui.py
@@ -0,0 +1,228 @@
+import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
+import json
+import urllib.request
+import urllib.parse
+import random
+
+from pydantic import BaseModel
+
+from typing import Optional
+
+COMFYUI_DEFAULT_PROMPT = """
+{
+ "3": {
+ "inputs": {
+ "seed": 0,
+ "steps": 20,
+ "cfg": 8,
+ "sampler_name": "euler",
+ "scheduler": "normal",
+ "denoise": 1,
+ "model": [
+ "4",
+ 0
+ ],
+ "positive": [
+ "6",
+ 0
+ ],
+ "negative": [
+ "7",
+ 0
+ ],
+ "latent_image": [
+ "5",
+ 0
+ ]
+ },
+ "class_type": "KSampler",
+ "_meta": {
+ "title": "KSampler"
+ }
+ },
+ "4": {
+ "inputs": {
+ "ckpt_name": "model.safetensors"
+ },
+ "class_type": "CheckpointLoaderSimple",
+ "_meta": {
+ "title": "Load Checkpoint"
+ }
+ },
+ "5": {
+ "inputs": {
+ "width": 512,
+ "height": 512,
+ "batch_size": 1
+ },
+ "class_type": "EmptyLatentImage",
+ "_meta": {
+ "title": "Empty Latent Image"
+ }
+ },
+ "6": {
+ "inputs": {
+ "text": "Prompt",
+ "clip": [
+ "4",
+ 1
+ ]
+ },
+ "class_type": "CLIPTextEncode",
+ "_meta": {
+ "title": "CLIP Text Encode (Prompt)"
+ }
+ },
+ "7": {
+ "inputs": {
+ "text": "Negative Prompt",
+ "clip": [
+ "4",
+ 1
+ ]
+ },
+ "class_type": "CLIPTextEncode",
+ "_meta": {
+ "title": "CLIP Text Encode (Prompt)"
+ }
+ },
+ "8": {
+ "inputs": {
+ "samples": [
+ "3",
+ 0
+ ],
+ "vae": [
+ "4",
+ 2
+ ]
+ },
+ "class_type": "VAEDecode",
+ "_meta": {
+ "title": "VAE Decode"
+ }
+ },
+ "9": {
+ "inputs": {
+ "filename_prefix": "ComfyUI",
+ "images": [
+ "8",
+ 0
+ ]
+ },
+ "class_type": "SaveImage",
+ "_meta": {
+ "title": "Save Image"
+ }
+ }
+}
+"""
+
+
+def queue_prompt(prompt, client_id, base_url):
+ print("queue_prompt")
+ p = {"prompt": prompt, "client_id": client_id}
+ data = json.dumps(p).encode("utf-8")
+ req = urllib.request.Request(f"{base_url}/prompt", data=data)
+ return json.loads(urllib.request.urlopen(req).read())
+
+
+def get_image(filename, subfolder, folder_type, base_url):
+ print("get_image")
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+ url_values = urllib.parse.urlencode(data)
+ with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
+ return response.read()
+
+
+def get_image_url(filename, subfolder, folder_type, base_url):
+    print("get_image_url")
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+ url_values = urllib.parse.urlencode(data)
+ return f"{base_url}/view?{url_values}"
+
+
+def get_history(prompt_id, base_url):
+ print("get_history")
+ with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
+ return json.loads(response.read())
+
+
+def get_images(ws, prompt, client_id, base_url):
+ prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
+ output_images = []
+ while True:
+ out = ws.recv()
+ if isinstance(out, str):
+ message = json.loads(out)
+ if message["type"] == "executing":
+ data = message["data"]
+ if data["node"] is None and data["prompt_id"] == prompt_id:
+ break # Execution is done
+ else:
+ continue # previews are binary data
+
+ history = get_history(prompt_id, base_url)[prompt_id]
+    # each node's images are appended exactly once (was a duplicated double loop)
+    for node_id in history["outputs"]:
+        node_output = history["outputs"][node_id]
+        if "images" in node_output:
+            for image in node_output["images"]:
+                url = get_image_url(
+                    image["filename"], image["subfolder"], image["type"], base_url
+                )
+                output_images.append({"url": url})
+ return {"data": output_images}
+
+
+class ImageGenerationPayload(BaseModel):
+ prompt: str
+ negative_prompt: Optional[str] = ""
+ steps: Optional[int] = None
+ seed: Optional[int] = None
+ width: int
+ height: int
+ n: int = 1
+
+
+def comfyui_generate_image(
+ model: str, payload: ImageGenerationPayload, client_id, base_url
+):
+ host = base_url.replace("http://", "").replace("https://", "")
+
+ comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
+
+ comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
+ comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
+ comfyui_prompt["5"]["inputs"]["width"] = payload.width
+ comfyui_prompt["5"]["inputs"]["height"] = payload.height
+
+ # set the text prompt for our positive CLIPTextEncode
+ comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
+ comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
+
+ if payload.steps:
+ comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
+
+ comfyui_prompt["3"]["inputs"]["seed"] = (
+ payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+ )
+
+ try:
+ ws = websocket.WebSocket()
+ ws.connect(f"ws://{host}/ws?clientId={client_id}")
+ print("WebSocket connection established.")
+ except Exception as e:
+ print(f"Failed to connect to WebSocket server: {e}")
+ return None
+
+ try:
+ images = get_images(ws, comfyui_prompt, client_id, base_url)
+ except Exception as e:
+ print(f"Error while receiving images: {e}")
+ images = None
+
+ ws.close()
+
+ return images
diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte
index ee481402e..7282c184a 100644
--- a/src/lib/components/chat/Settings/Images.svelte
+++ b/src/lib/components/chat/Settings/Images.svelte
@@ -323,6 +323,7 @@
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder={$i18n.t('Select a model')}
+ required
>
{#if !selectedModel}
diff --git a/src/lib/components/common/ImagePreview.svelte b/src/lib/components/common/ImagePreview.svelte
index cf69327fa..badabebda 100644
--- a/src/lib/components/common/ImagePreview.svelte
+++ b/src/lib/components/common/ImagePreview.svelte
@@ -2,6 +2,22 @@
export let show = false;
export let src = '';
export let alt = '';
+
+ const downloadImage = (url, filename) => {
+ fetch(url)
+ .then((response) => response.blob())
+ .then((blob) => {
+ const objectUrl = window.URL.createObjectURL(blob);
+ const link = document.createElement('a');
+ link.href = objectUrl;
+ link.download = filename;
+ document.body.appendChild(link);
+ link.click();
+ document.body.removeChild(link);
+ window.URL.revokeObjectURL(objectUrl);
+ })
+ .catch((error) => console.error('Error downloading image:', error));
+ };
{#if show}
@@ -35,10 +51,7 @@