This commit is contained in:
Timothy Jaeryang Baek 2024-12-12 20:26:28 -08:00
parent d8a01cb911
commit 8c38708827

View File

@ -195,7 +195,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}")
request.app.state.config.MODEL = model
request.app.state.config.IMAGE_GENERATION_MODEL = model
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth()
r = requests.get(
@ -210,18 +210,22 @@ def set_image_model(request: Request, model: str):
json=options,
headers={"authorization": api_auth},
)
return request.app.state.config.MODEL
return request.app.state.config.IMAGE_GENERATION_MODEL
def get_image_model():
def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
return (
request.app.state.config.MODEL
if request.app.state.config.MODEL
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else ""
)
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
@ -247,7 +251,7 @@ class ImageConfigForm(BaseModel):
@router.get("/image/config")
async def get_image_config(request: Request, user=Depends(get_admin_user)):
return {
"MODEL": request.app.state.config.MODEL,
"MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@ -278,7 +282,7 @@ async def update_image_config(
)
return {
"MODEL": request.app.state.config.MODEL,
"MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@ -450,8 +454,8 @@ async def image_generations(
data = {
"model": (
request.app.state.config.MODEL
if request.app.state.config.MODEL != ""
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt,
@ -513,7 +517,7 @@ async def image_generations(
}
)
res = await comfyui_generate_image(
request.app.state.config.MODEL,
request.app.state.config.IMAGE_GENERATION_MODEL,
form_data,
user.id,
request.app.state.config.COMFYUI_BASE_URL,