This commit is contained in:
Timothy Jaeryang Baek 2024-12-12 20:24:36 -08:00
parent 4311bb7b99
commit d8a01cb911

View File

@ -42,10 +42,10 @@ router = APIRouter()
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.ENGINE,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"openai": {
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
@ -93,11 +93,13 @@ class ConfigForm(BaseModel):
async def update_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENGINE = form_data.engine
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
form_data.openai.OPENAI_API_BASE_URL
)
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
@ -132,10 +134,10 @@ async def update_config(
return {
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.ENGINE,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"openai": {
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
@ -166,7 +168,7 @@ def get_automatic1111_api_auth(request: Request):
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
if request.app.state.config.ENGINE == "automatic1111":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
try:
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
@ -177,7 +179,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
except Exception:
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
try:
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
@ -194,7 +196,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}")
request.app.state.config.MODEL = model
if request.app.state.config.ENGINE in ["", "automatic1111"]:
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth()
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
@ -212,17 +214,17 @@ def set_image_model(request: Request, model: str):
def get_image_model():
if request.app.state.config.ENGINE == "openai":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
return (
request.app.state.config.MODEL
if request.app.state.config.MODEL
else "dall-e-2"
)
elif request.app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
elif (
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
try:
r = requests.get(
@ -285,12 +287,12 @@ async def update_image_config(
@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
try:
if request.app.state.config.ENGINE == "openai":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif request.app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
@ -336,8 +338,8 @@ def get_models(request: Request, user=Depends(get_verified_user)):
)
)
elif (
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
@ -433,10 +435,10 @@ async def image_generations(
r = None
try:
if request.app.state.config.ENGINE == "openai":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
headers = {}
headers["Authorization"] = (
f"Bearer {request.app.state.config.OPENAI_API_KEY}"
f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
)
headers["Content-Type"] = "application/json"
@ -465,7 +467,7 @@ async def image_generations(
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
@ -485,7 +487,7 @@ async def image_generations(
return images
elif request.app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
@ -531,8 +533,8 @@ async def image_generations(
log.debug(f"images: {images}")
return images
elif (
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
if form_data.model:
set_image_model(form_data.model)