diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index df96344ab..12812ae0a 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1650,6 +1650,13 @@ ENABLE_IMAGE_GENERATION = PersistentConfig( "image_generation.enable", os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", ) + +ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig( + "ENABLE_IMAGE_PROMPT_GENERATION", + "image_generation.prompt.enable", + os.environ.get("ENABLE_IMAGE_PROMPT_GENERATION", "true").lower() == "true", +) + AUTOMATIC1111_BASE_URL = PersistentConfig( "AUTOMATIC1111_BASE_URL", "image_generation.automatic1111.base_url", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index b13f957a5..00270aabc 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -108,6 +108,7 @@ from open_webui.config import ( COMFYUI_WORKFLOW, COMFYUI_WORKFLOW_NODES, ENABLE_IMAGE_GENERATION, + ENABLE_IMAGE_PROMPT_GENERATION, IMAGE_GENERATION_ENGINE, IMAGE_GENERATION_MODEL, IMAGE_SIZE, @@ -575,6 +576,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION +app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index c40d12522..f4833d0b5 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -43,6 +43,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, + "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, @@ -86,6 +87,7 @@ class ComfyUIConfigForm(BaseModel): class ConfigForm(BaseModel): enabled: bool engine: str + prompt_generation: bool openai: OpenAIConfigForm automatic1111: Automatic1111ConfigForm comfyui: ComfyUIConfigForm @@ -98,6 +100,10 @@ async def update_config( request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled + request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ( + form_data.prompt_generation + ) + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( form_data.openai.OPENAI_API_BASE_URL ) @@ -137,6 +143,7 @@ async def update_config( return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, + "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 61704513d..96d9ede0f 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -504,38 +504,39 @@ async def chat_image_generation_handler( messages = form_data["messages"] user_message = get_last_user_message(messages) - prompt = "" + prompt = user_message negative_prompt = "" - try: - res = await generate_image_prompt( - request, - { - "model": form_data["model"], - "messages": messages, - }, - user, - ) - - response = res["choices"][0]["message"]["content"] - + if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: try: - bracket_start = response.find("{") - bracket_end = response.rfind("}") + 1 + res = await generate_image_prompt( + request, + { + "model": form_data["model"], + "messages": messages, + }, + user, + ) - if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") + response = res["choices"][0]["message"]["content"] + + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + response = response[bracket_start:bracket_end] + response = json.loads(response) + prompt = response.get("prompt", []) + except Exception as e: + prompt = user_message - response = response[bracket_start:bracket_end] - response = json.loads(response) - prompt = response.get("prompt", []) except Exception as e: + log.exception(e) prompt = user_message - except Exception as e: - log.exception(e) - prompt = user_message - system_message_content = "" try: diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index b99eb4631..a875709df 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -234,7 +234,7 @@
{$i18n.t('Image Settings')}
-
+
{$i18n.t('Image Generation (Experimental)')}
@@ -271,7 +271,16 @@
-
+ {#if config.enabled} +
+
{$i18n.t('Image Prompt Generation')}
+
+ +
+
+ {/if} + +
{$i18n.t('Image Generation Engine')}