This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 18:53:38 -08:00
parent ccdf51588e
commit 772f5ccd60
3 changed files with 180 additions and 246 deletions

View File

@ -697,11 +697,11 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if filter_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
@ -828,7 +828,7 @@ async def chat_completion_tools_handler(
models,
)
tools = get_tools(
webui_app,
app,
tool_ids,
user,
{
@ -1406,7 +1406,7 @@ async def commit_session_after_request(request: Request, call_next):
@app.middleware("http")
async def check_url(request: Request, call_next):
start_time = int(time.time())
request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY
request.state.enable_api_key = app.state.config.ENABLE_API_KEY
response = await call_next(request)
process_time = int(time.time()) - start_time
response.headers["X-Process-Time"] = str(process_time)
@ -1913,11 +1913,11 @@ async def get_all_models():
]
def get_function_module_by_id(function_id):
if function_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[function_id]
if function_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[function_id]
else:
function_module, _, _ = load_function_module_by_id(function_id)
webui_app.state.FUNCTIONS[function_id] = function_module
app.state.FUNCTIONS[function_id] = function_module
for model in models:
action_ids = [
@ -1953,7 +1953,7 @@ async def get_models(user=Depends(get_verified_user)):
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
]
model_order_list = webui_app.state.config.MODEL_ORDER_LIST
model_order_list = app.state.config.MODEL_ORDER_LIST
if model_order_list:
model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)}
# Sort models by order list priority, with fallback for those not in the list
@ -2229,11 +2229,11 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if filter_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
@ -2340,11 +2340,11 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
}
)
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
if action_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module
app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
@ -2448,17 +2448,17 @@ async def get_app_config(request: Request):
},
"features": {
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_ldap": webui_app.state.config.ENABLE_LDAP,
"enable_api_key": webui_app.state.config.ENABLE_API_KEY,
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_ldap": app.state.config.ENABLE_LDAP,
"enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
**(
{
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
}
@ -2468,8 +2468,8 @@ async def get_app_config(request: Request):
},
**(
{
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
@ -2484,7 +2484,7 @@ async def get_app_config(request: Request):
"max_size": retrieval_app.state.config.FILE_MAX_SIZE,
"max_count": retrieval_app.state.config.FILE_MAX_COUNT,
},
"permissions": {**webui_app.state.config.USER_PERMISSIONS},
"permissions": {**app.state.config.USER_PERMISSIONS},
}
if user is not None
else {}
@ -2506,7 +2506,7 @@ async def get_webhook_url(user=Depends(get_admin_user)):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {"url": app.state.config.WEBHOOK_URL}

View File

@ -9,6 +9,18 @@ from pathlib import Path
from typing import Optional
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
@ -16,48 +28,36 @@ from open_webui.utils.images.comfyui import (
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
router = APIRouter()
@app.get("/config")
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"enabled": app.state.config.ENABLED,
"engine": app.state.config.ENGINE,
"enabled": request.app.state.config.ENABLED,
"engine": request.app.state.config.ENGINE,
"openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
}
@ -89,133 +89,150 @@ class ConfigForm(BaseModel):
comfyui: ComfyUIConfigForm
@app.post("/config/update")
async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
app.state.config.ENGINE = form_data.engine
app.state.config.ENABLED = form_data.enabled
@router.post("/config/update")
async def update_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENGINE = form_data.engine
request.app.state.config.ENABLED = form_data.enabled
app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
app.state.config.AUTOMATIC1111_BASE_URL = (
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
)
app.state.config.AUTOMATIC1111_API_AUTH = (
request.app.state.config.AUTOMATIC1111_API_AUTH = (
form_data.automatic1111.AUTOMATIC1111_API_AUTH
)
app.state.config.AUTOMATIC1111_CFG_SCALE = (
request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
else None
)
app.state.config.AUTOMATIC1111_SAMPLER = (
request.app.state.config.AUTOMATIC1111_SAMPLER = (
form_data.automatic1111.AUTOMATIC1111_SAMPLER
if form_data.automatic1111.AUTOMATIC1111_SAMPLER
else None
)
app.state.config.AUTOMATIC1111_SCHEDULER = (
request.app.state.config.AUTOMATIC1111_SCHEDULER = (
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
else None
)
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
)
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
)
return {
"enabled": app.state.config.ENABLED,
"engine": app.state.config.ENGINE,
"enabled": request.app.state.config.ENABLED,
"engine": request.app.state.config.ENGINE,
"openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
}
def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH is None:
def get_automatic1111_api_auth(request: Request):
if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
return ""
else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
"utf-8"
)
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
return f"Basic {auth1111_base64_encoded_string}"
@app.get("/config/url/verify")
async def verify_url(user=Depends(get_admin_user)):
if app.state.config.ENGINE == "automatic1111":
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
if request.app.state.config.ENGINE == "automatic1111":
try:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth(request)},
)
r.raise_for_status()
return True
except Exception:
app.state.config.ENABLED = False
request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.ENGINE == "comfyui":
try:
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
)
r.raise_for_status()
return True
except Exception:
app.state.config.ENABLED = False
request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
else:
return True
def set_image_model(model: str):
def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}")
app.state.config.MODEL = model
if app.state.config.ENGINE in ["", "automatic1111"]:
request.app.state.config.MODEL = model
if request.app.state.config.ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth()
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": api_auth},
)
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
headers={"authorization": api_auth},
)
return app.state.config.MODEL
return request.app.state.config.MODEL
def get_image_model():
if app.state.config.ENGINE == "openai":
return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
elif app.state.config.ENGINE == "comfyui":
return app.state.config.MODEL if app.state.config.MODEL else ""
elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
if request.app.state.config.ENGINE == "openai":
return (
request.app.state.config.MODEL
if request.app.state.config.MODEL
else "dall-e-2"
)
elif request.app.state.config.ENGINE == "comfyui":
return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
elif (
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
):
try:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
)
options = r.json()
return options["sd_model_checkpoint"]
except Exception as e:
app.state.config.ENABLED = False
request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -225,23 +242,25 @@ class ImageConfigForm(BaseModel):
IMAGE_STEPS: int
@app.get("/image/config")
async def get_image_config(user=Depends(get_admin_user)):
@router.get("/image/config")
async def get_image_config(request: Request, user=Depends(get_admin_user)):
return {
"MODEL": app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"MODEL": request.app.state.config.MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@app.post("/image/config/update")
async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
@router.post("/image/config/update")
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(form_data.MODEL)
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
if re.match(pattern, form_data.IMAGE_SIZE):
app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
status_code=400,
@ -249,7 +268,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
)
if form_data.IMAGE_STEPS >= 0:
app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
else:
raise HTTPException(
status_code=400,
@ -257,29 +276,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
)
return {
"MODEL": app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"MODEL": request.app.state.config.MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
try:
if app.state.config.ENGINE == "openai":
if request.app.state.config.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.ENGINE == "comfyui":
# TODO - get models from comfyui
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
)
info = r.json()
workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
model_node_id = None
for node in app.state.config.COMFYUI_WORKFLOW_NODES:
for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
if node["type"] == "model":
if node["node_ids"]:
model_node_id = node["node_ids"][0]
@ -315,10 +336,11 @@ def get_models(user=Depends(get_verified_user)):
)
)
elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
):
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
headers={"authorization": get_automatic1111_api_auth()},
)
models = r.json()
@ -329,7 +351,7 @@ def get_models(user=Depends(get_verified_user)):
)
)
except Exception as e:
app.state.config.ENABLED = False
request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -401,18 +423,21 @@ def save_url_image(url):
return None
@app.post("/generations")
@router.post("/generations")
async def image_generations(
request: Request,
form_data: GenerateImageForm,
user=Depends(get_verified_user),
):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
r = None
try:
if app.state.config.ENGINE == "openai":
if request.app.state.config.ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Authorization"] = (
f"Bearer {request.app.state.config.OPENAI_API_KEY}"
)
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
@ -423,14 +448,16 @@ async def image_generations(
data = {
"model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
request.app.state.config.MODEL
if request.app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt,
"n": form_data.n,
"size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
form_data.size
if form_data.size
else request.app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
}
@ -438,7 +465,7 @@ async def image_generations(
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
@ -458,7 +485,7 @@ async def image_generations(
return images
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
@ -466,8 +493,8 @@ async def image_generations(
"n": form_data.n,
}
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
@ -476,18 +503,18 @@ async def image_generations(
**{
"workflow": ComfyUIWorkflow(
**{
"workflow": app.state.config.COMFYUI_WORKFLOW,
"nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
"workflow": request.app.state.config.COMFYUI_WORKFLOW,
"nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}
),
**data,
}
)
res = await comfyui_generate_image(
app.state.config.MODEL,
request.app.state.config.MODEL,
form_data,
user.id,
app.state.config.COMFYUI_BASE_URL,
request.app.state.config.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
@ -504,7 +531,8 @@ async def image_generations(
log.debug(f"images: {images}")
return images
elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
):
if form_data.model:
set_image_model(form_data.model)
@ -516,25 +544,25 @@ async def image_generations(
"height": height,
}
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
if app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
if app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
if request.app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
if app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
if request.app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
headers={"authorization": get_automatic1111_api_auth()},
)

View File

@ -1,94 +0,0 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.socket.main import get_event_call, get_event_emitter
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.routers import (
auths,
chats,
folders,
configs,
groups,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
)
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
MODEL_ORDER_LIST,
ENABLE_COMMUNITY_SHARING,
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_API_KEY,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
ENABLE_LDAP,
LDAP_SERVER_LABEL,
LDAP_SERVER_HOST,
LDAP_SERVER_PORT,
LDAP_ATTRIBUTE_FOR_USERNAME,
LDAP_SEARCH_FILTERS,
LDAP_SEARCH_BASE,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
LDAP_USE_TLS,
LDAP_CA_CERT_FILE,
LDAP_CIPHERS,
AppConfig,
)
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.tools import get_tools
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])