diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index e49c251a1..4ca32f7fa 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1587,6 +1587,12 @@ COMFYUI_BASE_URL = PersistentConfig( os.getenv("COMFYUI_BASE_URL", ""), ) +COMFYUI_API_KEY = PersistentConfig( + "COMFYUI_API_KEY", + "image_generation.comfyui.api_key", + os.getenv("COMFYUI_API_KEY", ""), +) + COMFYUI_DEFAULT_WORKFLOW = """ { "3": { diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 31604984f..64e95ea0c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -96,6 +96,7 @@ from open_webui.config import ( AUTOMATIC1111_SAMPLER, AUTOMATIC1111_SCHEDULER, COMFYUI_BASE_URL, + COMFYUI_API_KEY, COMFYUI_WORKFLOW, COMFYUI_WORKFLOW_NODES, ENABLE_IMAGE_GENERATION, @@ -557,6 +558,7 @@ app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.COMFYUI_API_KEY = COMFYUI_API_KEY app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 3f51fbdb4..5778279c0 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -56,6 +56,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): }, "comfyui": { "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, @@ -77,6 +78,7 @@ class Automatic1111ConfigForm(BaseModel): class ComfyUIConfigForm(BaseModel): COMFYUI_BASE_URL: str + COMFYUI_API_KEY: str COMFYUI_WORKFLOW: str COMFYUI_WORKFLOW_NODES: list[dict] @@ -148,6 +150,7 @@ async def update_config( }, "comfyui": { "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, @@ -298,9 +301,8 @@ def get_models(request: Request, user=Depends(get_verified_user)): ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui - r = requests.get( - url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" - ) + headers = {"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"} + r = requests.get(url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info", headers=headers) info = r.json() workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW) @@ -521,6 +523,7 @@ async def image_generations( form_data, user.id, request.app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_API_KEY, ) log.debug(f"res: {res}") diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py index 4c421d7c5..745d79635 100644 --- a/backend/open_webui/utils/images/comfyui.py +++ b/backend/open_webui/utils/images/comfyui.py @@ -16,14 +16,14 @@ log.setLevel(SRC_LOG_LEVELS["COMFYUI"]) default_headers = {"User-Agent": "Mozilla/5.0"} -def queue_prompt(prompt, client_id, base_url): +def queue_prompt(prompt, client_id, base_url,api_key): log.info("queue_prompt") p = {"prompt": prompt, "client_id": client_id} data = json.dumps(p).encode("utf-8") log.debug(f"queue_prompt data: {data}") try: req = urllib.request.Request( - f"{base_url}/prompt", data=data, headers=default_headers + f"{base_url}/prompt", data=data, headers={**default_headers, "Authorization": f"Bearer {api_key}"} ) response = urllib.request.urlopen(req).read() return json.loads(response) @@ -32,12 +32,12 @@ def queue_prompt(prompt, client_id, base_url): raise e -def get_image(filename, subfolder, folder_type, base_url): +def get_image(filename, subfolder, folder_type, base_url, api_key): log.info("get_image") data = {"filename": filename, "subfolder": subfolder, "type": folder_type} url_values = urllib.parse.urlencode(data) req = urllib.request.Request( - f"{base_url}/view?{url_values}", headers=default_headers + f"{base_url}/view?{url_values}", headers={**default_headers, "Authorization": f"Bearer {api_key}"} ) with urllib.request.urlopen(req) as response: return response.read() @@ -50,18 +50,18 @@ def get_image_url(filename, subfolder, folder_type, base_url): return f"{base_url}/view?{url_values}" -def get_history(prompt_id, base_url): +def get_history(prompt_id, base_url, api_key): log.info("get_history") req = urllib.request.Request( - f"{base_url}/history/{prompt_id}", headers=default_headers + f"{base_url}/history/{prompt_id}", headers={**default_headers, "Authorization": f"Bearer {api_key}"} ) with urllib.request.urlopen(req) as response: return json.loads(response.read()) -def get_images(ws, prompt, client_id, base_url): - prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] +def get_images(ws, prompt, client_id, base_url, api_key): + prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"] output_images = [] while True: out = ws.recv() @@ -74,7 +74,7 @@ def get_images(ws, prompt, client_id, base_url): else: continue # previews are binary data - history = get_history(prompt_id, base_url)[prompt_id] + history = get_history(prompt_id, base_url, api_key)[prompt_id] for o in history["outputs"]: for node_id in history["outputs"]: node_output = history["outputs"][node_id] @@ -113,7 +113,7 @@ class ComfyUIGenerateImageForm(BaseModel): async def comfyui_generate_image( - model: str, payload: ComfyUIGenerateImageForm, client_id, base_url + model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key ): ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") workflow = json.loads(payload.workflow.workflow) @@ -167,7 +167,8 @@ async def comfyui_generate_image( try: ws = websocket.WebSocket() - ws.connect(f"{ws_url}/ws?clientId={client_id}") + headers = {"Authorization": f"Bearer {api_key}"} + ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers) log.info("WebSocket connection established.") except Exception as e: log.exception(f"Failed to connect to WebSocket server: {e}") @@ -176,7 +177,7 @@ async def comfyui_generate_image( try: log.info("Sending workflow to WebSocket server.") log.info(f"Workflow: {workflow}") - images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url) + images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url, api_key) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index b0492f24b..02ff49bac 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -470,6 +470,20 @@ +