mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:dev' into dev
This commit is contained in:
committed by
GitHub
commit
a713e14db8
@@ -199,6 +199,7 @@ CHANGELOG = changelog_json
|
||||
|
||||
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
|
||||
|
||||
|
||||
####################################
|
||||
# ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
####################################
|
||||
@@ -272,15 +273,13 @@ if "postgres://" in DATABASE_URL:
|
||||
|
||||
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
|
||||
|
||||
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
|
||||
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
|
||||
|
||||
if DATABASE_POOL_SIZE == "":
|
||||
DATABASE_POOL_SIZE = 0
|
||||
else:
|
||||
if DATABASE_POOL_SIZE != None:
|
||||
try:
|
||||
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
|
||||
except Exception:
|
||||
DATABASE_POOL_SIZE = 0
|
||||
DATABASE_POOL_SIZE = None
|
||||
|
||||
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
|
||||
|
||||
@@ -396,6 +395,10 @@ WEBUI_AUTH_COOKIE_SECURE = (
|
||||
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
||||
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
||||
|
||||
ENABLE_COMPRESSION_MIDDLEWARE = (
|
||||
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
ENABLE_WEBSOCKET_SUPPORT = (
|
||||
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -62,6 +62,9 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize the database connection: {e}")
|
||||
log.warning(
|
||||
"Hint: If your database password contains special characters, you may need to URL-encode it."
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Properly closing the database connection
|
||||
@@ -81,20 +84,23 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
else:
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
pool_size=DATABASE_POOL_SIZE,
|
||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||
pool_timeout=DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=DATABASE_POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
if isinstance(DATABASE_POOL_SIZE, int):
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
pool_size=DATABASE_POOL_SIZE,
|
||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||
pool_timeout=DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=DATABASE_POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
|
||||
@@ -411,6 +411,7 @@ from open_webui.env import (
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
ENABLE_COMPRESSION_MIDDLEWARE,
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
RESET_CONFIG_ON_START,
|
||||
@@ -1072,7 +1073,9 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
# Add the middleware to the app
|
||||
app.add_middleware(CompressMiddleware)
|
||||
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||
app.add_middleware(CompressMiddleware)
|
||||
|
||||
app.add_middleware(RedirectMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
@@ -303,8 +303,14 @@ async def update_image_config(
|
||||
):
|
||||
set_image_model(request, form_data.MODEL)
|
||||
|
||||
if (form_data.IMAGE_SIZE == "auto" and form_data.MODEL != 'gpt-image-1'):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (auto is only allowed with gpt-image-1).")
|
||||
)
|
||||
|
||||
pattern = r"^\d+x\d+$"
|
||||
if re.match(pattern, form_data.IMAGE_SIZE):
|
||||
if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
|
||||
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -472,7 +478,14 @@ async def image_generations(
|
||||
form_data: GenerateImageForm,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
|
||||
# if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
|
||||
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
|
||||
# image model other than gpt-image-1, which is warned about on settings save
|
||||
width, height = (
|
||||
tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
|
||||
if 'x' in request.app.state.config.IMAGE_SIZE
|
||||
else (512, 512)
|
||||
)
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -633,13 +633,7 @@ async def verify_connection(
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
|
||||
def convert_to_azure_payload(
|
||||
url,
|
||||
payload: dict,
|
||||
):
|
||||
model = payload.get("model", "")
|
||||
|
||||
# Filter allowed parameters based on Azure OpenAI API
|
||||
def get_azure_allowed_params(api_version: str) -> set[str]:
|
||||
allowed_params = {
|
||||
"messages",
|
||||
"temperature",
|
||||
@@ -669,6 +663,23 @@ def convert_to_azure_payload(
|
||||
"max_completion_tokens",
|
||||
}
|
||||
|
||||
try:
|
||||
if api_version >= "2024-09-01-preview":
|
||||
allowed_params.add("stream_options")
|
||||
except ValueError:
|
||||
log.debug(
|
||||
f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
|
||||
)
|
||||
|
||||
return allowed_params
|
||||
|
||||
|
||||
def convert_to_azure_payload(url, payload: dict, api_version: str):
|
||||
model = payload.get("model", "")
|
||||
|
||||
# Filter allowed parameters based on Azure OpenAI API
|
||||
allowed_params = get_azure_allowed_params(api_version)
|
||||
|
||||
# Special handling for o-series models
|
||||
if model.startswith("o") and model.endswith("-mini"):
|
||||
# Convert max_tokens to max_completion_tokens for o-series models
|
||||
@@ -817,8 +828,8 @@ async def generate_chat_completion(
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
request_url, payload = convert_to_azure_payload(url, payload)
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||
request_url, payload = convert_to_azure_payload(url, payload, api_version)
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = api_version
|
||||
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
||||
@@ -1007,16 +1018,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = (
|
||||
api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
)
|
||||
headers["api-version"] = api_version
|
||||
|
||||
payload = json.loads(body)
|
||||
url, payload = convert_to_azure_payload(url, payload)
|
||||
url, payload = convert_to_azure_payload(url, payload, api_version)
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
|
||||
request_url = f"{url}/{path}?api-version={api_version}"
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
request_url = f"{url}/{path}"
|
||||
|
||||
@@ -804,7 +804,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
raise e
|
||||
|
||||
try:
|
||||
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
@@ -1741,7 +1740,7 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
async def stream_body_handler(response):
|
||||
async def stream_body_handler(response, form_data):
|
||||
nonlocal content
|
||||
nonlocal content_blocks
|
||||
|
||||
@@ -1770,7 +1769,7 @@ async def process_chat_response(
|
||||
filter_functions=filter_functions,
|
||||
filter_type="stream",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
extra_params={"__body__": form_data, **extra_params},
|
||||
)
|
||||
|
||||
if data:
|
||||
@@ -2032,7 +2031,7 @@ async def process_chat_response(
|
||||
if response.background:
|
||||
await response.background()
|
||||
|
||||
await stream_body_handler(response)
|
||||
await stream_body_handler(response, form_data)
|
||||
|
||||
MAX_TOOL_CALL_RETRIES = 10
|
||||
tool_call_retries = 0
|
||||
@@ -2181,22 +2180,24 @@ async def process_chat_response(
|
||||
)
|
||||
|
||||
try:
|
||||
new_form_data = {
|
||||
"model": model_id,
|
||||
"stream": True,
|
||||
"tools": form_data["tools"],
|
||||
"messages": [
|
||||
*form_data["messages"],
|
||||
*convert_content_blocks_to_messages(content_blocks),
|
||||
],
|
||||
}
|
||||
|
||||
res = await generate_chat_completion(
|
||||
request,
|
||||
{
|
||||
"model": model_id,
|
||||
"stream": True,
|
||||
"tools": form_data["tools"],
|
||||
"messages": [
|
||||
*form_data["messages"],
|
||||
*convert_content_blocks_to_messages(content_blocks),
|
||||
],
|
||||
},
|
||||
new_form_data,
|
||||
user,
|
||||
)
|
||||
|
||||
if isinstance(res, StreamingResponse):
|
||||
await stream_body_handler(res)
|
||||
await stream_body_handler(res, new_form_data)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
|
||||
@@ -101,9 +101,6 @@ def get_tools(
|
||||
|
||||
def make_tool_function(function_name, token, tool_server_data):
|
||||
async def tool_function(**kwargs):
|
||||
print(
|
||||
f"Executing tool function {function_name} with params: {kwargs}"
|
||||
)
|
||||
return await execute_tool_server(
|
||||
token=token,
|
||||
url=tool_server_data["url"],
|
||||
|
||||
Reference in New Issue
Block a user