mirror of
https://github.com/open-webui/open-webui
synced 2025-01-19 09:16:44 +00:00
enh: image prompt enhancer
This commit is contained in:
parent
d3a5b9c127
commit
0360aa5520
@ -1055,6 +1055,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.image.prompt_template",
|
||||
os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task:
|
||||
Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements.
|
||||
|
||||
### Guidelines:
|
||||
- Be descriptive and detailed, focusing on the most important aspects of the image.
|
||||
- Avoid making assumptions or adding information not present in the image.
|
||||
- Use the chat's primary language; default to English if multilingual.
|
||||
- If the image is too complex, focus on the most prominent elements.
|
||||
|
||||
### Output:
|
||||
Strictly return in JSON format:
|
||||
{
|
||||
"prompt": "Your detailed description here."
|
||||
}
|
||||
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
ENABLE_TAGS_GENERATION = PersistentConfig(
|
||||
"ENABLE_TAGS_GENERATION",
|
||||
"task.tags.enable",
|
||||
|
@ -113,6 +113,7 @@ class TASKS(str, Enum):
|
||||
TAGS_GENERATION = "tags_generation"
|
||||
EMOJI_GENERATION = "emoji_generation"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
|
||||
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
MOA_RESPONSE_GENERATION = "moa_response_generation"
|
||||
|
@ -255,6 +255,7 @@ from open_webui.config import (
|
||||
ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
||||
@ -644,6 +645,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
||||
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.task import (
|
||||
title_generation_template,
|
||||
query_generation_template,
|
||||
image_prompt_generation_template,
|
||||
autocomplete_generation_template,
|
||||
tags_generation_template,
|
||||
emoji_generation_template,
|
||||
@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id
|
||||
from open_webui.config import (
|
||||
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
|
||||
@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
@ -114,6 +118,7 @@ async def update_task_config(
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
@ -256,6 +261,66 @@ async def generate_chat_tags(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/image_prompt/completions")
|
||||
async def generate_image_prompt(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating image prompt using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = image_prompt_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
user={
|
||||
"name": user.name,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/queries/completions")
|
||||
async def generate_queries(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
|
@ -28,6 +28,7 @@ from open_webui.socket.main import (
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_image_prompt,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
@ -503,12 +504,44 @@ async def chat_image_generation_handler(
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
|
||||
try:
|
||||
res = await generate_image_prompt(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
response = res["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
bracket_start = response.find("{")
|
||||
bracket_end = response.rfind("}") + 1
|
||||
|
||||
if bracket_start == -1 or bracket_end == -1:
|
||||
raise Exception("No JSON object found in the response")
|
||||
|
||||
response = response[bracket_start:bracket_end]
|
||||
response = json.loads(response)
|
||||
prompt = response.get("prompt", [])
|
||||
except Exception as e:
|
||||
prompt = user_message
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
prompt = user_message
|
||||
|
||||
system_message_content = ""
|
||||
|
||||
try:
|
||||
images = await image_generations(
|
||||
request=request,
|
||||
form_data=GenerateImageForm(**{"prompt": user_message}),
|
||||
form_data=GenerateImageForm(**{"prompt": prompt}),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
@ -217,6 +217,24 @@ def tags_generation_template(
|
||||
return template
|
||||
|
||||
|
||||
def image_prompt_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def emoji_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
@ -24,6 +24,7 @@
|
||||
TASK_MODEL: '',
|
||||
TASK_MODEL_EXTERNAL: '',
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: '',
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: true,
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: '',
|
||||
@ -140,6 +141,22 @@
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
<div class="mt-3">
|
||||
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Image Prompt Generation Prompt')}</div>
|
||||
|
||||
<Tooltip
|
||||
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
|
||||
placement="top-start"
|
||||
>
|
||||
<Textarea
|
||||
bind:value={taskConfig.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE}
|
||||
placeholder={$i18n.t(
|
||||
'Leave empty to use the default prompt, or enter a custom prompt'
|
||||
)}
|
||||
/>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
<hr class=" border-gray-50 dark:border-gray-850 my-3" />
|
||||
|
||||
<div class="my-3 flex w-full items-center justify-between">
|
||||
|
Loading…
Reference in New Issue
Block a user