mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into fix-12237
This commit is contained in:
@@ -194,8 +194,8 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
ciphers=LDAP_CIPHERS,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"An error occurred on TLS: {str(e)}")
|
||||
raise HTTPException(400, detail=str(e))
|
||||
log.error(f"TLS configuration error: {str(e)}")
|
||||
raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")
|
||||
|
||||
try:
|
||||
server = Server(
|
||||
@@ -232,7 +232,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||
email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not email or email == "" or email == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have email.")
|
||||
raise HTTPException(400, "User does not have a valid email address.")
|
||||
else:
|
||||
email = email.lower()
|
||||
|
||||
@@ -248,7 +248,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
authentication="SIMPLE",
|
||||
)
|
||||
if not connection_user.bind():
|
||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||
raise HTTPException(400, "Authentication failed.")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
@@ -276,7 +276,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"LDAP user creation error: {str(err)}")
|
||||
raise HTTPException(
|
||||
500, detail="Internal error occurred during LDAP user creation."
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user_by_trusted_header(email)
|
||||
|
||||
@@ -312,12 +315,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
else:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
|
||||
)
|
||||
raise HTTPException(400, "User record mismatch.")
|
||||
except Exception as e:
|
||||
raise HTTPException(400, detail=str(e))
|
||||
log.error(f"LDAP authentication error: {str(e)}")
|
||||
raise HTTPException(400, detail="LDAP authentication failed.")
|
||||
|
||||
|
||||
############################
|
||||
@@ -519,7 +520,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"Signup error: {str(err)}")
|
||||
raise HTTPException(500, detail="An internal error occurred during signup.")
|
||||
|
||||
|
||||
@router.get("/signout")
|
||||
@@ -547,7 +549,11 @@ async def signout(request: Request, response: Response):
|
||||
detail="Failed to fetch OpenID configuration",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
log.error(f"OpenID signout error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to sign out from the OpenID provider.",
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
|
||||
@@ -591,7 +597,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"Add user error: {str(err)}")
|
||||
raise HTTPException(
|
||||
500, detail="An internal error occurred while adding the user."
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.config import get_config, save_config
|
||||
from open_webui.config import BannerModel
|
||||
|
||||
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,6 +68,75 @@ async def set_direct_connections_config(
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# ToolServers Config
|
||||
############################
|
||||
|
||||
|
||||
class ToolServerConnection(BaseModel):
|
||||
url: str
|
||||
path: str
|
||||
auth_type: Optional[str]
|
||||
key: Optional[str]
|
||||
config: Optional[dict]
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ToolServersConfigForm(BaseModel):
|
||||
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
|
||||
|
||||
|
||||
@router.get("/tool_servers", response_model=ToolServersConfigForm)
|
||||
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tool_servers", response_model=ToolServersConfigForm)
|
||||
async def set_tool_servers_config(
|
||||
request: Request,
|
||||
form_data: ToolServersConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tool_servers/verify")
|
||||
async def verify_tool_servers_config(
|
||||
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
|
||||
):
|
||||
"""
|
||||
Verify the connection to the tool server.
|
||||
"""
|
||||
try:
|
||||
|
||||
token = None
|
||||
if form_data.auth_type == "bearer":
|
||||
token = form_data.key
|
||||
elif form_data.auth_type == "session":
|
||||
token = request.state.token.credentials
|
||||
|
||||
url = f"{form_data.url}/{form_data.path}"
|
||||
return await get_tool_server_data(token, url)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to connect to the tool server: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
|
||||
@@ -1197,7 +1197,7 @@ class OpenAIChatMessageContent(BaseModel):
|
||||
|
||||
class OpenAIChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Union[str, list[OpenAIChatMessageContent]]
|
||||
content: Union[Optional[str], list[OpenAIChatMessageContent]]
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -1534,8 +1534,13 @@ def query_doc_handler(
|
||||
):
|
||||
try:
|
||||
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
collection_results = {}
|
||||
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
||||
collection_name=form_data.collection_name
|
||||
)
|
||||
return query_doc_with_hybrid_search(
|
||||
collection_name=form_data.collection_name,
|
||||
collection_result=collection_results[form_data.collection_name],
|
||||
query=form_data.query,
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, prefix=prefix, user=user
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import time
|
||||
|
||||
from open_webui.models.tools import (
|
||||
ToolForm,
|
||||
@@ -18,6 +19,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
@@ -30,11 +33,51 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ToolUserResponse])
|
||||
async def get_tools(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
tools = Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "read")
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if not request.app.state.TOOL_SERVERS:
|
||||
# If the tool servers are not set, we need to set them
|
||||
# This is done only once when the server starts
|
||||
# This is done to avoid loading the tool servers every time
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
tools = Tools.get_tools()
|
||||
for idx, server in enumerate(request.app.state.TOOL_SERVERS):
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
"id": f"server:{server['idx']}",
|
||||
"user_id": f"server:{server['idx']}",
|
||||
"name": server["openapi"]
|
||||
.get("info", {})
|
||||
.get("title", "Tool Server"),
|
||||
"meta": {
|
||||
"description": server["openapi"]
|
||||
.get("info", {})
|
||||
.get("description", ""),
|
||||
},
|
||||
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
||||
idx
|
||||
]
|
||||
.get("config", {})
|
||||
.get("access_control", None),
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if user.role != "admin":
|
||||
tools = [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user.id
|
||||
or has_access(user.id, "read", tool.access_control)
|
||||
]
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user