From a10302d909b1101a60b0b3c112acd4546c68af6d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 15 Jan 2025 23:32:13 -0800 Subject: [PATCH] enh: image generation toggle --- backend/open_webui/utils/middleware.py | 69 +++++++++++++++++++ src/lib/components/chat/Chat.svelte | 4 ++ src/lib/components/chat/MessageInput.svelte | 4 ++ .../chat/MessageInput/InputMenu.svelte | 37 +++++++++- src/lib/components/chat/Placeholder.svelte | 2 + src/lib/components/icons/PhotoSolid.svelte | 11 +++ 6 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 src/lib/components/icons/PhotoSolid.svelte diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 5980d4af0..221847d07 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -31,6 +31,9 @@ from open_webui.routers.tasks import ( generate_chat_tags, ) from open_webui.routers.retrieval import process_web_search, SearchForm +from open_webui.routers.images import image_generations, GenerateImageForm + + from open_webui.utils.webhook import post_webhook @@ -486,6 +489,67 @@ async def chat_web_search_handler( return form_data +async def chat_image_generation_handler( + request: Request, form_data: dict, extra_params: dict, user +): + __event_emitter__ = extra_params["__event_emitter__"] + await __event_emitter__( + { + "type": "status", + "data": {"description": "Generating an image", "done": False}, + } + ) + + messages = form_data["messages"] + user_message = get_last_user_message(messages) + + system_message_content = "" + + try: + images = await image_generations( + request=request, + form_data=GenerateImageForm(**{"prompt": user_message}), + user=user, + ) + + await __event_emitter__( + { + "type": "status", + "data": {"description": "Generated an image", "done": True}, + } + ) + + for image in images: + await __event_emitter__( + { + "type": "message", + "data": {"content": f"![Generated Image]({image['url']})"}, + } + ) + + system_message_content = "User is shown the generated image, tell the user that the image has been generated" + except Exception as e: + log.exception(e) + await __event_emitter__( + { + "type": "status", + "data": { + "description": f"An error occured while generating an image", + "done": True, + }, + } + ) + + system_message_content = "Unable to generate an image, tell the user that an error occured" + + if system_message_content: + form_data["messages"] = add_or_update_system_message( + system_message_content, form_data["messages"] + ) + + return form_data + + async def chat_completion_files_handler( request: Request, body: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: @@ -640,6 +704,11 @@ async def process_chat_payload(request, form_data, metadata, user, model): request, form_data, extra_params, user ) + if "image_generation" in features and features["image_generation"]: + form_data = await chat_image_generation_handler( + request, form_data, extra_params, user + ) + try: form_data, flags = await chat_completion_filter_functions_handler( request, form_data, model, extra_params diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index bca957cca..adce45bd3 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -111,6 +111,7 @@ $: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels; let selectedToolIds = []; + let imageGenerationEnabled = false; let webSearchEnabled = false; let chat = null; @@ -1533,6 +1534,7 @@ files: (files?.length ?? 0) > 0 ? files : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, features: { + image_generation: imageGenerationEnabled, web_search: webSearchEnabled }, @@ -1935,6 +1937,7 @@ bind:prompt bind:autoScroll bind:selectedToolIds + bind:imageGenerationEnabled bind:webSearchEnabled bind:atSelectedModel transparentBackground={$settings?.backgroundImageUrl ?? false} @@ -1985,6 +1988,7 @@ bind:prompt bind:autoScroll bind:selectedToolIds + bind:imageGenerationEnabled bind:webSearchEnabled bind:atSelectedModel transparentBackground={$settings?.backgroundImageUrl ?? false} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 5ec2a6d84..44a32ce3e 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -62,12 +62,15 @@ export let files = []; export let selectedToolIds = []; + + export let imageGenerationEnabled = false; export let webSearchEnabled = false; $: onChange({ prompt, files, selectedToolIds, + imageGenerationEnabled, webSearchEnabled }); @@ -642,6 +645,7 @@
{/if} - {#if $config?.features?.enable_web_search && ($user.role === 'admin' || $user?.permissions?.features?.web_search)} + {#if showImageGeneration} + + {/if} + + {#if showWebSearch} + {/if} + {#if showImageGeneration || showWebSearch}
{/if} diff --git a/src/lib/components/chat/Placeholder.svelte b/src/lib/components/chat/Placeholder.svelte index 6782ea628..336e3002c 100644 --- a/src/lib/components/chat/Placeholder.svelte +++ b/src/lib/components/chat/Placeholder.svelte @@ -34,6 +34,7 @@ export let files = []; export let selectedToolIds = []; + export let imageGenerationEnabled = false; export let webSearchEnabled = false; let models = []; @@ -194,6 +195,7 @@ bind:prompt bind:autoScroll bind:selectedToolIds + bind:imageGenerationEnabled bind:webSearchEnabled bind:atSelectedModel {transparentBackground} diff --git a/src/lib/components/icons/PhotoSolid.svelte b/src/lib/components/icons/PhotoSolid.svelte new file mode 100644 index 000000000..a004e9b5d --- /dev/null +++ b/src/lib/components/icons/PhotoSolid.svelte @@ -0,0 +1,11 @@ + + + + +