This commit is contained in:
Timothy Jaeryang Baek
2025-01-16 00:13:02 -08:00
parent 0360aa5520
commit 0425621494
5 changed files with 52 additions and 26 deletions

View File

@@ -1650,6 +1650,13 @@ ENABLE_IMAGE_GENERATION = PersistentConfig(
"image_generation.enable",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
"ENABLE_IMAGE_PROMPT_GENERATION",
"image_generation.prompt.enable",
os.environ.get("ENABLE_IMAGE_PROMPT_GENERATION", "true").lower() == "true",
)
AUTOMATIC1111_BASE_URL = PersistentConfig(
"AUTOMATIC1111_BASE_URL",
"image_generation.automatic1111.base_url",

View File

@@ -108,6 +108,7 @@ from open_webui.config import (
COMFYUI_WORKFLOW,
COMFYUI_WORKFLOW_NODES,
ENABLE_IMAGE_GENERATION,
ENABLE_IMAGE_PROMPT_GENERATION,
IMAGE_GENERATION_ENGINE,
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
@@ -575,6 +576,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

View File

@@ -43,6 +43,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
"openai": {
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
@@ -86,6 +87,7 @@ class ComfyUIConfigForm(BaseModel):
class ConfigForm(BaseModel):
enabled: bool
engine: str
prompt_generation: bool
openai: OpenAIConfigForm
automatic1111: Automatic1111ConfigForm
comfyui: ComfyUIConfigForm
@@ -98,6 +100,10 @@ async def update_config(
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
form_data.prompt_generation
)
request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
form_data.openai.OPENAI_API_BASE_URL
)
@@ -137,6 +143,7 @@ async def update_config(
return {
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
"openai": {
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,

View File

@@ -504,38 +504,39 @@ async def chat_image_generation_handler(
messages = form_data["messages"]
user_message = get_last_user_message(messages)
prompt = ""
prompt = user_message
negative_prompt = ""
try:
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": messages,
},
user,
)
response = res["choices"][0]["message"]["content"]
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": messages,
},
user,
)
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
response = json.loads(response)
prompt = response.get("prompt", [])
except Exception as e:
prompt = user_message
response = response[bracket_start:bracket_end]
response = json.loads(response)
prompt = response.get("prompt", [])
except Exception as e:
log.exception(e)
prompt = user_message
except Exception as e:
log.exception(e)
prompt = user_message
system_message_content = ""
try: