diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 4239f3f45..a418f2693 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -514,7 +514,7 @@ async def image_generations( data = ImageGenerationPayload(**data) - res = comfyui_generate_image( + res = await comfyui_generate_image( app.state.config.MODEL, data, user.id, diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 6c37f0c49..ec0f8c59e 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -1,3 +1,4 @@ +import asyncio import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import uuid import json @@ -328,7 +329,7 @@ class ImageGenerationPayload(BaseModel): flux_fp8_clip: Optional[bool] = None -def comfyui_generate_image( +async def comfyui_generate_image( model: str, payload: ImageGenerationPayload, client_id, base_url ): ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") @@ -377,9 +378,9 @@ def comfyui_generate_image( comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype if payload.flux_fp8_clip: - comfyui_prompt["11"]["inputs"][ - "clip_name2" - ] = "t5xxl_fp8_e4m3fn.safetensors" + comfyui_prompt["11"]["inputs"]["clip_name2"] = ( + "t5xxl_fp8_e4m3fn.safetensors" + ) comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["width"] = payload.width @@ -397,7 +398,7 @@ def comfyui_generate_image( return None try: - images = get_images(ws, comfyui_prompt, client_id, base_url) + images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 831da783b..50de53a53 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -359,10 +359,10 @@ async def generate_chat_completion( ): idx = 0 payload = {**form_data} - + if "metadata" in payload: del payload["metadata"] - + model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 3dc1cf7ee..9de19d3f6 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -192,7 +192,6 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: if (param := params.get(key, None)) is not None: form_data[value] = param - print(form_data) return form_data