Merge branch 'open-webui:main' into fix-12237

This commit is contained in:
Juan Calderon-Perez
2025-04-06 13:30:37 -04:00
committed by GitHub
93 changed files with 3815 additions and 778 deletions

View File

@@ -194,8 +194,8 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
ciphers=LDAP_CIPHERS,
)
except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}")
raise HTTPException(400, detail=str(e))
log.error(f"TLS configuration error: {str(e)}")
raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")
try:
server = Server(
@@ -232,7 +232,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
if not email or email == "" or email == "[]":
raise HTTPException(400, f"User {form_data.user} does not have email.")
raise HTTPException(400, "User does not have a valid email address.")
else:
email = email.lower()
@@ -248,7 +248,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
authentication="SIMPLE",
)
if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}")
raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(email)
if not user:
@@ -276,7 +276,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
except HTTPException:
raise
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"LDAP user creation error: {str(err)}")
raise HTTPException(
500, detail="Internal error occurred during LDAP user creation."
)
user = Auths.authenticate_user_by_trusted_header(email)
@@ -312,12 +315,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
raise HTTPException(
400,
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
raise HTTPException(400, "User record mismatch.")
except Exception as e:
raise HTTPException(400, detail=str(e))
log.error(f"LDAP authentication error: {str(e)}")
raise HTTPException(400, detail="LDAP authentication failed.")
############################
@@ -519,7 +520,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Signup error: {str(err)}")
raise HTTPException(500, detail="An internal error occurred during signup.")
@router.get("/signout")
@@ -547,7 +549,11 @@ async def signout(request: Request, response: Response):
detail="Failed to fetch OpenID configuration",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
log.error(f"OpenID signout error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Failed to sign out from the OpenID provider.",
)
return {"status": True}
@@ -591,7 +597,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Add user error: {str(err)}")
raise HTTPException(
500, detail="An internal error occurred while adding the user."
)
############################

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
from typing import Optional
@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config
from open_webui.config import BannerModel
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
router = APIRouter()
@@ -66,6 +68,75 @@ async def set_direct_connections_config(
}
############################
# ToolServers Config
############################
class ToolServerConnection(BaseModel):
url: str
path: str
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
model_config = ConfigDict(extra="allow")
class ToolServersConfigForm(BaseModel):
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
@router.get("/tool_servers", response_model=ToolServersConfigForm)
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers", response_model=ToolServersConfigForm)
async def set_tool_servers_config(
request: Request,
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers/verify")
async def verify_tool_servers_config(
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
):
"""
Verify the connection to the tool server.
"""
try:
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
url = f"{form_data.url}/{form_data.path}"
return await get_tool_server_data(token, url)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to connect to the tool server: {str(e)}",
)
############################
# CodeInterpreterConfig
############################

View File

@@ -1197,7 +1197,7 @@ class OpenAIChatMessageContent(BaseModel):
class OpenAIChatMessage(BaseModel):
role: str
content: Union[str, list[OpenAIChatMessageContent]]
content: Union[Optional[str], list[OpenAIChatMessageContent]]
model_config = ConfigDict(extra="allow")

View File

@@ -1534,8 +1534,13 @@ def query_doc_handler(
):
try:
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
collection_results = {}
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
collection_name=form_data.collection_name
)
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name],
query=form_data.query,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user

View File

@@ -1,6 +1,7 @@
import logging
from pathlib import Path
from typing import Optional
import time
from open_webui.models.tools import (
ToolForm,
@@ -18,6 +19,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -30,11 +33,51 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "read")
async def get_tools(request: Request, user=Depends(get_verified_user)):
if not request.app.state.TOOL_SERVERS:
# If the tool servers are not set, we need to set them
# This is done only once when the server starts
# This is done to avoid loading the tool servers every time
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
tools = Tools.get_tools()
for idx, server in enumerate(request.app.state.TOOL_SERVERS):
tools.append(
ToolUserResponse(
**{
"id": f"server:{server['idx']}",
"user_id": f"server:{server['idx']}",
"name": server["openapi"]
.get("info", {})
.get("title", "Tool Server"),
"meta": {
"description": server["openapi"]
.get("info", {})
.get("description", ""),
},
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
idx
]
.get("config", {})
.get("access_control", None),
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
)
if user.role != "admin":
tools = [
tool
for tool in tools
if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control)
]
return tools