This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 18:53:38 -08:00
parent ccdf51588e
commit 772f5ccd60
3 changed files with 180 additions and 246 deletions

View File

@ -697,11 +697,11 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
if not filter: if not filter:
continue continue
if filter_id in webui_app.state.FUNCTIONS: if filter_id in app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id] function_module = app.state.FUNCTIONS[filter_id]
else: else:
function_module, _, _ = load_function_module_by_id(filter_id) function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable # Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"): if hasattr(function_module, "file_handler"):
@ -828,7 +828,7 @@ async def chat_completion_tools_handler(
models, models,
) )
tools = get_tools( tools = get_tools(
webui_app, app,
tool_ids, tool_ids,
user, user,
{ {
@ -1406,7 +1406,7 @@ async def commit_session_after_request(request: Request, call_next):
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
start_time = int(time.time()) start_time = int(time.time())
request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY request.state.enable_api_key = app.state.config.ENABLE_API_KEY
response = await call_next(request) response = await call_next(request)
process_time = int(time.time()) - start_time process_time = int(time.time()) - start_time
response.headers["X-Process-Time"] = str(process_time) response.headers["X-Process-Time"] = str(process_time)
@ -1913,11 +1913,11 @@ async def get_all_models():
] ]
def get_function_module_by_id(function_id): def get_function_module_by_id(function_id):
if function_id in webui_app.state.FUNCTIONS: if function_id in app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[function_id] function_module = app.state.FUNCTIONS[function_id]
else: else:
function_module, _, _ = load_function_module_by_id(function_id) function_module, _, _ = load_function_module_by_id(function_id)
webui_app.state.FUNCTIONS[function_id] = function_module app.state.FUNCTIONS[function_id] = function_module
for model in models: for model in models:
action_ids = [ action_ids = [
@ -1953,7 +1953,7 @@ async def get_models(user=Depends(get_verified_user)):
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
] ]
model_order_list = webui_app.state.config.MODEL_ORDER_LIST model_order_list = app.state.config.MODEL_ORDER_LIST
if model_order_list: if model_order_list:
model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)}
# Sort models by order list priority, with fallback for those not in the list # Sort models by order list priority, with fallback for those not in the list
@ -2229,11 +2229,11 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
if not filter: if not filter:
continue continue
if filter_id in webui_app.state.FUNCTIONS: if filter_id in app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id] function_module = app.state.FUNCTIONS[filter_id]
else: else:
function_module, _, _ = load_function_module_by_id(filter_id) function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id) valves = Functions.get_function_valves_by_id(filter_id)
@ -2340,11 +2340,11 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
} }
) )
if action_id in webui_app.state.FUNCTIONS: if action_id in app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id] function_module = app.state.FUNCTIONS[action_id]
else: else:
function_module, _, _ = load_function_module_by_id(action_id) function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id) valves = Functions.get_function_valves_by_id(action_id)
@ -2448,17 +2448,17 @@ async def get_app_config(request: Request):
}, },
"features": { "features": {
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_ldap": webui_app.state.config.ENABLE_LDAP, "enable_ldap": app.state.config.ENABLE_LDAP,
"enable_api_key": webui_app.state.config.ENABLE_API_KEY, "enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": webui_app.state.config.ENABLE_SIGNUP, "enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
**( **(
{ {
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED, "enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
} }
@ -2468,8 +2468,8 @@ async def get_app_config(request: Request):
}, },
**( **(
{ {
"default_models": webui_app.state.config.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"audio": { "audio": {
"tts": { "tts": {
"engine": audio_app.state.config.TTS_ENGINE, "engine": audio_app.state.config.TTS_ENGINE,
@ -2484,7 +2484,7 @@ async def get_app_config(request: Request):
"max_size": retrieval_app.state.config.FILE_MAX_SIZE, "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
"max_count": retrieval_app.state.config.FILE_MAX_COUNT, "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
}, },
"permissions": {**webui_app.state.config.USER_PERMISSIONS}, "permissions": {**app.state.config.USER_PERMISSIONS},
} }
if user is not None if user is not None
else {} else {}
@ -2506,7 +2506,7 @@ async def get_webhook_url(user=Depends(get_admin_user)):
@app.post("/api/webhook") @app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {"url": app.state.config.WEBHOOK_URL} return {"url": app.state.config.WEBHOOK_URL}

View File

@ -9,6 +9,18 @@ from pathlib import Path
from typing import Optional from typing import Optional
import requests import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import ( from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm, ComfyUIGenerateImageForm,
ComfyUIWorkflow, ComfyUIWorkflow,
@ -16,48 +28,36 @@ from open_webui.utils.images.comfyui import (
) )
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"]) log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None, router = APIRouter()
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
@app.get("/config") @router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return { return {
"enabled": app.state.config.ENABLED, "enabled": request.app.state.config.ENABLED,
"engine": app.state.config.ENGINE, "engine": request.app.state.config.ENGINE,
"openai": { "openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
}, },
"automatic1111": { "automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
}, },
"comfyui": { "comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}, },
} }
@ -89,133 +89,150 @@ class ConfigForm(BaseModel):
comfyui: ComfyUIConfigForm comfyui: ComfyUIConfigForm
@app.post("/config/update") @router.post("/config/update")
async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): async def update_config(
app.state.config.ENGINE = form_data.engine request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
app.state.config.ENABLED = form_data.enabled ):
request.app.state.config.ENGINE = form_data.engine
request.app.state.config.ENABLED = form_data.enabled
app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
app.state.config.AUTOMATIC1111_BASE_URL = ( request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL form_data.automatic1111.AUTOMATIC1111_BASE_URL
) )
app.state.config.AUTOMATIC1111_API_AUTH = ( request.app.state.config.AUTOMATIC1111_API_AUTH = (
form_data.automatic1111.AUTOMATIC1111_API_AUTH form_data.automatic1111.AUTOMATIC1111_API_AUTH
) )
app.state.config.AUTOMATIC1111_CFG_SCALE = ( request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
else None else None
) )
app.state.config.AUTOMATIC1111_SAMPLER = ( request.app.state.config.AUTOMATIC1111_SAMPLER = (
form_data.automatic1111.AUTOMATIC1111_SAMPLER form_data.automatic1111.AUTOMATIC1111_SAMPLER
if form_data.automatic1111.AUTOMATIC1111_SAMPLER if form_data.automatic1111.AUTOMATIC1111_SAMPLER
else None else None
) )
app.state.config.AUTOMATIC1111_SCHEDULER = ( request.app.state.config.AUTOMATIC1111_SCHEDULER = (
form_data.automatic1111.AUTOMATIC1111_SCHEDULER form_data.automatic1111.AUTOMATIC1111_SCHEDULER
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
else None else None
) )
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") request.app.state.config.COMFYUI_BASE_URL = (
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW form_data.comfyui.COMFYUI_BASE_URL.strip("/")
app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES )
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
)
return { return {
"enabled": app.state.config.ENABLED, "enabled": request.app.state.config.ENABLED,
"engine": app.state.config.ENGINE, "engine": request.app.state.config.ENGINE,
"openai": { "openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
}, },
"automatic1111": { "automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
}, },
"comfyui": { "comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}, },
} }
def get_automatic1111_api_auth(): def get_automatic1111_api_auth(request: Request):
if app.state.config.AUTOMATIC1111_API_AUTH is None: if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
return "" return ""
else: else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
"utf-8"
)
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
return f"Basic {auth1111_base64_encoded_string}" return f"Basic {auth1111_base64_encoded_string}"
@app.get("/config/url/verify") @router.get("/config/url/verify")
async def verify_url(user=Depends(get_admin_user)): async def verify_url(request: Request, user=Depends(get_admin_user)):
if app.state.config.ENGINE == "automatic1111": if request.app.state.config.ENGINE == "automatic1111":
try: try:
r = requests.get( r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()}, headers={"authorization": get_automatic1111_api_auth(request)},
) )
r.raise_for_status() r.raise_for_status()
return True return True
except Exception: except Exception:
app.state.config.ENABLED = False request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif app.state.config.ENGINE == "comfyui": elif request.app.state.config.ENGINE == "comfyui":
try: try:
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
)
r.raise_for_status() r.raise_for_status()
return True return True
except Exception: except Exception:
app.state.config.ENABLED = False request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
else: else:
return True return True
def set_image_model(model: str): def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}") log.info(f"Setting image model to {model}")
app.state.config.MODEL = model request.app.state.config.MODEL = model
if app.state.config.ENGINE in ["", "automatic1111"]: if request.app.state.config.ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth() api_auth = get_automatic1111_api_auth()
r = requests.get( r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": api_auth}, headers={"authorization": api_auth},
) )
options = r.json() options = r.json()
if model != options["sd_model_checkpoint"]: if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model options["sd_model_checkpoint"] = model
r = requests.post( r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options, json=options,
headers={"authorization": api_auth}, headers={"authorization": api_auth},
) )
return app.state.config.MODEL return request.app.state.config.MODEL
def get_image_model(): def get_image_model():
if app.state.config.ENGINE == "openai": if request.app.state.config.ENGINE == "openai":
return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" return (
elif app.state.config.ENGINE == "comfyui": request.app.state.config.MODEL
return app.state.config.MODEL if app.state.config.MODEL else "" if request.app.state.config.MODEL
elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": else "dall-e-2"
)
elif request.app.state.config.ENGINE == "comfyui":
return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
elif (
request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
):
try: try:
r = requests.get( r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()}, headers={"authorization": get_automatic1111_api_auth()},
) )
options = r.json() options = r.json()
return options["sd_model_checkpoint"] return options["sd_model_checkpoint"]
except Exception as e: except Exception as e:
app.state.config.ENABLED = False request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -225,23 +242,25 @@ class ImageConfigForm(BaseModel):
IMAGE_STEPS: int IMAGE_STEPS: int
@app.get("/image/config") @router.get("/image/config")
async def get_image_config(user=Depends(get_admin_user)): async def get_image_config(request: Request, user=Depends(get_admin_user)):
return { return {
"MODEL": app.state.config.MODEL, "MODEL": request.app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE, "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS, "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
} }
@app.post("/image/config/update") @router.post("/image/config/update")
async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(form_data.MODEL) set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$" pattern = r"^\d+x\d+$"
if re.match(pattern, form_data.IMAGE_SIZE): if re.match(pattern, form_data.IMAGE_SIZE):
app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -249,7 +268,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
) )
if form_data.IMAGE_STEPS >= 0: if form_data.IMAGE_STEPS >= 0:
app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -257,29 +276,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
) )
return { return {
"MODEL": app.state.config.MODEL, "MODEL": request.app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE, "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS, "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
} }
@app.get("/models") @router.get("/models")
def get_models(user=Depends(get_verified_user)): def get_models(request: Request, user=Depends(get_verified_user)):
try: try:
if app.state.config.ENGINE == "openai": if request.app.state.config.ENGINE == "openai":
return [ return [
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
] ]
elif app.state.config.ENGINE == "comfyui": elif request.app.state.config.ENGINE == "comfyui":
# TODO - get models from comfyui # TODO - get models from comfyui
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
)
info = r.json() info = r.json()
workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
model_node_id = None model_node_id = None
for node in app.state.config.COMFYUI_WORKFLOW_NODES: for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
if node["type"] == "model": if node["type"] == "model":
if node["node_ids"]: if node["node_ids"]:
model_node_id = node["node_ids"][0] model_node_id = node["node_ids"][0]
@ -315,10 +336,11 @@ def get_models(user=Depends(get_verified_user)):
) )
) )
elif ( elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
): ):
r = requests.get( r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
headers={"authorization": get_automatic1111_api_auth()}, headers={"authorization": get_automatic1111_api_auth()},
) )
models = r.json() models = r.json()
@ -329,7 +351,7 @@ def get_models(user=Depends(get_verified_user)):
) )
) )
except Exception as e: except Exception as e:
app.state.config.ENABLED = False request.app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -401,18 +423,21 @@ def save_url_image(url):
return None return None
@app.post("/generations") @router.post("/generations")
async def image_generations( async def image_generations(
request: Request,
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
r = None r = None
try: try:
if app.state.config.ENGINE == "openai": if request.app.state.config.ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Authorization"] = (
f"Bearer {request.app.state.config.OPENAI_API_KEY}"
)
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS: if ENABLE_FORWARD_USER_INFO_HEADERS:
@ -423,14 +448,16 @@ async def image_generations(
data = { data = {
"model": ( "model": (
app.state.config.MODEL request.app.state.config.MODEL
if app.state.config.MODEL != "" if request.app.state.config.MODEL != ""
else "dall-e-2" else "dall-e-2"
), ),
"prompt": form_data.prompt, "prompt": form_data.prompt,
"n": form_data.n, "n": form_data.n,
"size": ( "size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE form_data.size
if form_data.size
else request.app.state.config.IMAGE_SIZE
), ),
"response_format": "b64_json", "response_format": "b64_json",
} }
@ -438,7 +465,7 @@ async def image_generations(
# Use asyncio.to_thread for the requests.post call # Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread( r = await asyncio.to_thread(
requests.post, requests.post,
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data, json=data,
headers=headers, headers=headers,
) )
@ -458,7 +485,7 @@ async def image_generations(
return images return images
elif app.state.config.ENGINE == "comfyui": elif request.app.state.config.ENGINE == "comfyui":
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
"width": width, "width": width,
@ -466,8 +493,8 @@ async def image_generations(
"n": form_data.n, "n": form_data.n,
} }
if app.state.config.IMAGE_STEPS is not None: if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
@ -476,18 +503,18 @@ async def image_generations(
**{ **{
"workflow": ComfyUIWorkflow( "workflow": ComfyUIWorkflow(
**{ **{
"workflow": app.state.config.COMFYUI_WORKFLOW, "workflow": request.app.state.config.COMFYUI_WORKFLOW,
"nodes": app.state.config.COMFYUI_WORKFLOW_NODES, "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
} }
), ),
**data, **data,
} }
) )
res = await comfyui_generate_image( res = await comfyui_generate_image(
app.state.config.MODEL, request.app.state.config.MODEL,
form_data, form_data,
user.id, user.id,
app.state.config.COMFYUI_BASE_URL, request.app.state.config.COMFYUI_BASE_URL,
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
@ -504,7 +531,8 @@ async def image_generations(
log.debug(f"images: {images}") log.debug(f"images: {images}")
return images return images
elif ( elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" request.app.state.config.ENGINE == "automatic1111"
or request.app.state.config.ENGINE == ""
): ):
if form_data.model: if form_data.model:
set_image_model(form_data.model) set_image_model(form_data.model)
@ -516,25 +544,25 @@ async def image_generations(
"height": height, "height": height,
} }
if app.state.config.IMAGE_STEPS is not None: if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
if app.state.config.AUTOMATIC1111_CFG_SCALE: if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
if app.state.config.AUTOMATIC1111_SAMPLER: if request.app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
if app.state.config.AUTOMATIC1111_SCHEDULER: if request.app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
# Use asyncio.to_thread for the requests.post call # Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread( r = await asyncio.to_thread(
requests.post, requests.post,
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data, json=data,
headers={"authorization": get_automatic1111_api_auth()}, headers={"authorization": get_automatic1111_api_auth()},
) )

View File

@ -1,94 +0,0 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.socket.main import get_event_call, get_event_emitter
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.routers import (
auths,
chats,
folders,
configs,
groups,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
)
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
MODEL_ORDER_LIST,
ENABLE_COMMUNITY_SHARING,
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_API_KEY,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
ENABLE_LDAP,
LDAP_SERVER_LABEL,
LDAP_SERVER_HOST,
LDAP_SERVER_PORT,
LDAP_ATTRIBUTE_FOR_USERNAME,
LDAP_SEARCH_FILTERS,
LDAP_SEARCH_BASE,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
LDAP_USE_TLS,
LDAP_CA_CERT_FILE,
LDAP_CIPHERS,
AppConfig,
)
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.tools import get_tools
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])