From e56b5c063ca72fb5fa3f37965730b63e078a42bd Mon Sep 17 00:00:00 2001 From: JoaoCostaIFG Date: Tue, 18 Feb 2025 22:39:32 +0000 Subject: [PATCH 1/2] feat: add Google Imagen/Gemini API image generation Adds support for Gemini API as an image generation backend. By setting the API Base URL to something like 'https://generativelanguage.googleapis.com/v1beta' and providing their API Key, users should be able to start generating images using models like 'imagen-3.0-generate-002'. --- backend/open_webui/config.py | 14 +++++ backend/open_webui/main.py | 5 ++ backend/open_webui/routers/images.py | 63 +++++++++++++++++++ .../components/admin/Settings/Images.svelte | 22 +++++++ 4 files changed, 104 insertions(+) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index adfdcfec8..0b147cf12 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -767,6 +767,9 @@ ENABLE_OPENAI_API = PersistentConfig( OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") +GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "") + if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -2064,6 +2067,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig( os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), ) +IMAGES_GEMINI_API_BASE_URL = PersistentConfig( + "IMAGES_GEMINI_API_BASE_URL", + "image_generation.gemini.api_base_url", + os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), +) +IMAGES_GEMINI_API_KEY = PersistentConfig( + "IMAGES_GEMINI_API_KEY", + "image_generation.gemini.api_key", + os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY), +) + IMAGE_SIZE = PersistentConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a36323151..22b0526da 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -125,6 +125,8 @@ from open_webui.config import ( IMAGE_STEPS, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, + IMAGES_GEMINI_API_BASE_URL, + IMAGES_GEMINI_API_KEY, # Audio AUDIO_STT_ENGINE, AUDIO_STT_MODEL, @@ -631,6 +633,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL +app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY + app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 4046773de..4c68442b7 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel): COMFYUI_WORKFLOW_NODES: list[dict] +class GeminiConfigForm(BaseModel): + GEMINI_API_BASE_URL: str + GEMINI_API_KEY: str + + class ConfigForm(BaseModel): enabled: bool engine: str @@ -85,6 +94,7 @@ class ConfigForm(BaseModel): openai: OpenAIConfigForm automatic1111: Automatic1111ConfigForm comfyui: ComfyUIConfigForm + gemini: GeminiConfigForm @router.post("/config/update") @@ -103,6 +113,11 @@ async def update_config( ) request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_GEMINI_API_BASE_URL = ( + form_data.gemini.GEMINI_API_BASE_URL + ) + request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) @@ -155,6 +170,10 @@ async def update_config( "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -224,6 +243,12 @@ def get_image_model(request): if request.app.state.config.IMAGE_GENERATION_MODEL else "dall-e-2" ) + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "imagen-3.0-generate-002" + ) elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": return ( request.app.state.config.IMAGE_GENERATION_MODEL @@ -299,6 +324,10 @@ def get_models(request: Request, user=Depends(get_verified_user)): {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return [ + {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"}, + ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui headers = { @@ -483,6 +512,40 @@ async def image_generations( images.append({"url": url}) return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + headers = {} + headers["Content-Type"] = "application/json" + api_key = request.app.state.config.IMAGES_GEMINI_API_KEY + model = get_image_model(request) + data = { + "instances": {"prompt": form_data.prompt}, + "parameters": { + "sampleCount": form_data.n, + "outputOptions": {"mimeType": "image/png"}, + }, + } + + # Use asyncio.to_thread for the requests.post call + r = await asyncio.to_thread( + requests.post, + url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict?key={api_key}", + json=data, + headers=headers, + ) + + r.raise_for_status() + res = r.json() + + images = [] + for image in res["predictions"]: + image_data, content_type = load_b64_image_data( + image["bytesBase64Encoded"] + ) + url = upload_image(request, data, image_data, content_type, user) + images.append({"url": url}) + + return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index 4277c4ebd..590886702 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -261,6 +261,9 @@ } else if (config.engine === 'openai' && config.openai.OPENAI_API_KEY === '') { toast.error($i18n.t('OpenAI API Key is required.')); config.enabled = false; + } else if (config.engine === 'gemini' && config.gemini.GEMINI_API_KEY === '') { + toast.error($i18n.t('Gemini API Key is required.')); + config.enabled = false; } } @@ -294,6 +297,7 @@ + @@ -605,6 +609,24 @@ /> + {:else if config?.engine === 'gemini'} +
+
{$i18n.t('Gemini API Config')}
+ +
+ + + +
+
{/if} From 918764a4f7093ec0838ceeff358356e0942978e6 Mon Sep 17 00:00:00 2001 From: JoaoCostaIFG Date: Wed, 19 Feb 2025 00:00:54 +0000 Subject: [PATCH 2/2] fix: Use x-goog-api-key header for Gemini image generation Place the API key in a header instead of a query parameter. This avoids leaking the API key in logs on request failure, etc... --- backend/open_webui/routers/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 4c68442b7..3288ec6d8 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -515,7 +515,8 @@ async def image_generations( elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": headers = {} headers["Content-Type"] = "application/json" - api_key = request.app.state.config.IMAGES_GEMINI_API_KEY + headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY + model = get_image_model(request) data = { "instances": {"prompt": form_data.prompt}, @@ -528,7 +529,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict?key={api_key}", + url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict", json=data, headers=headers, )