From 0360aa5520550ae5c733a77464d0a14901a0f3ee Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 16 Jan 2025 00:06:37 -0800 Subject: [PATCH] enh: image prompt enhancer --- backend/open_webui/config.py | 26 ++++++++ backend/open_webui/constants.py | 1 + backend/open_webui/main.py | 5 ++ backend/open_webui/routers/tasks.py | 65 +++++++++++++++++++ backend/open_webui/utils/middleware.py | 35 +++++++++- backend/open_webui/utils/task.py | 18 +++++ .../admin/Settings/Interface.svelte | 17 +++++ 7 files changed, 166 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index e32fc1a17..df96344ab 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1055,6 +1055,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } {{MESSAGES:END:6}} """ +IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", + "task.image.prompt_template", + os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task: +Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements. + +### Guidelines: +- Be descriptive and detailed, focusing on the most important aspects of the image. +- Avoid making assumptions or adding information not present in the image. +- Use the chat's primary language; default to English if multilingual. +- If the image is too complex, focus on the most prominent elements. + +### Output: +Strictly return in JSON format: +{ + "prompt": "Your detailed description here." +} + +### Chat History: + +{{MESSAGES:END:6}} +""" + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index c5fdfabfb..cb65e0d77 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -113,6 +113,7 @@ class TASKS(str, Enum): TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" + IMAGE_PROMPT_GENERATION = "image_prompt_generation" AUTOCOMPLETE_GENERATION = "autocomplete_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6414bacca..b13f957a5 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -255,6 +255,7 @@ from open_webui.config import ( ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, QUERY_GENERATION_PROMPT_TEMPLATE, AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, @@ -644,6 +645,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE +app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE +) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 7d14a9d18..6d7343c8a 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, query_generation_template, + image_prompt_generation_template, autocomplete_generation_template, tags_generation_template, emoji_generation_template, @@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id from open_webui.config import ( DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, + DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, @@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, @@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] TITLE_GENERATION_PROMPT_TEMPLATE: str + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str ENABLE_AUTOCOMPLETE_GENERATION: bool AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str @@ -114,6 +118,7 @@ async def update_task_config( "TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, @@ -256,6 +261,66 @@ async def generate_chat_tags( ) +@router.post("/image_prompt/completions") +async def generate_image_prompt( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating image prompt using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE + + content = image_prompt_generation_template( + template, + form_data["messages"], + user={ + "name": user.name, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.IMAGE_PROMPT_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) + + @router.post("/queries/completions") async def generate_queries( request: Request, form_data: dict, user=Depends(get_verified_user) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 221847d07..61704513d 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -28,6 +28,7 @@ from open_webui.socket.main import ( from open_webui.routers.tasks import ( generate_queries, generate_title, + generate_image_prompt, generate_chat_tags, ) from open_webui.routers.retrieval import process_web_search, SearchForm @@ -503,12 +504,44 @@ async def chat_image_generation_handler( messages = form_data["messages"] user_message = get_last_user_message(messages) + prompt = "" + negative_prompt = "" + + try: + res = await generate_image_prompt( + request, + { + "model": form_data["model"], + "messages": messages, + }, + user, + ) + + response = res["choices"][0]["message"]["content"] + + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + response = response[bracket_start:bracket_end] + response = json.loads(response) + prompt = response.get("prompt", []) + except Exception as e: + prompt = user_message + + except Exception as e: + log.exception(e) + prompt = user_message + system_message_content = "" try: images = await image_generations( request=request, - form_data=GenerateImageForm(**{"prompt": user_message}), + form_data=GenerateImageForm(**{"prompt": prompt}), user=user, ) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index ebb7483ba..f5ba75ebe 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -217,6 +217,24 @@ def tags_generation_template( return template +def image_prompt_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def emoji_generation_template( template: str, prompt: str, user: Optional[dict] = None ) -> str: diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 9c669dae5..055acbf80 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -24,6 +24,7 @@ TASK_MODEL: '', TASK_MODEL_EXTERNAL: '', TITLE_GENERATION_PROMPT_TEMPLATE: '', + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '', ENABLE_AUTOCOMPLETE_GENERATION: true, AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1, TAGS_GENERATION_PROMPT_TEMPLATE: '', @@ -140,6 +141,22 @@ +
+
{$i18n.t('Image Prompt Generation Prompt')}
+ + +