enh: image prompt enhancer

This commit is contained in:
Timothy Jaeryang Baek 2025-01-16 00:06:37 -08:00
parent d3a5b9c127
commit 0360aa5520
7 changed files with 166 additions and 1 deletions

View File

@ -1055,6 +1055,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
{{MESSAGES:END:6}}
</chat_history>"""
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
"task.image.prompt_template",
os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task:
Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements.
### Guidelines:
- Be descriptive and detailed, focusing on the most important aspects of the image.
- Avoid making assumptions or adding information not present in the image.
- Use the chat's primary language; default to English if multilingual.
- If the image is too complex, focus on the most prominent elements.
### Output:
Strictly return in JSON format:
{
"prompt": "Your detailed description here."
}
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",

View File

@ -113,6 +113,7 @@ class TASKS(str, Enum):
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

View File

@ -255,6 +255,7 @@ from open_webui.config import (
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
QUERY_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@ -644,6 +645,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)

View File

@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
@ -114,6 +118,7 @@ async def update_task_config(
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@ -256,6 +261,66 @@ async def generate_chat_tags(
)
@router.post("/image_prompt/completions")
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating image prompt using model {task_model_id} for user {user.email} "
)
if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
content = image_prompt_generation_template(
template,
form_data["messages"],
user={
"name": user.name,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "An internal error has occurred."},
)
@router.post("/queries/completions")
async def generate_queries(
request: Request, form_data: dict, user=Depends(get_verified_user)

View File

@ -28,6 +28,7 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
@ -503,12 +504,44 @@ async def chat_image_generation_handler(
messages = form_data["messages"]
user_message = get_last_user_message(messages)
prompt = ""
negative_prompt = ""
try:
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": messages,
},
user,
)
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
response = json.loads(response)
prompt = response.get("prompt", [])
except Exception as e:
prompt = user_message
except Exception as e:
log.exception(e)
prompt = user_message
system_message_content = ""
try:
images = await image_generations(
request=request,
form_data=GenerateImageForm(**{"prompt": user_message}),
form_data=GenerateImageForm(**{"prompt": prompt}),
user=user,
)

View File

@ -217,6 +217,24 @@ def tags_generation_template(
return template
def image_prompt_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def emoji_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:

View File

@ -24,6 +24,7 @@
TASK_MODEL: '',
TASK_MODEL_EXTERNAL: '',
TITLE_GENERATION_PROMPT_TEMPLATE: '',
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_AUTOCOMPLETE_GENERATION: true,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
TAGS_GENERATION_PROMPT_TEMPLATE: '',
@ -140,6 +141,22 @@
</Tooltip>
</div>
<div class="mt-3">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Image Prompt Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
/>
</Tooltip>
</div>
<hr class=" border-gray-50 dark:border-gray-850 my-3" />
<div class="my-3 flex w-full items-center justify-between">