This commit is contained in:
Timothy Jaeryang Baek
2025-04-05 04:05:52 -06:00
parent 0f310b3509
commit 0c0505e1cd
12 changed files with 613 additions and 368 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
from typing import Optional
@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config
from open_webui.config import BannerModel
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
router = APIRouter()
@@ -66,6 +68,73 @@ async def set_direct_connections_config(
}
############################
# ToolServers Config
############################
class ToolServerConnection(BaseModel):
url: str
path: str
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
model_config = ConfigDict(extra="allow")
class ToolServersConfigForm(BaseModel):
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
@router.get("/tool_servers", response_model=ToolServersConfigForm)
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers", response_model=ToolServersConfigForm)
async def set_tool_servers_config(
request: Request,
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.TOOL_SERVER_CONNECTIONS = form_data.TOOL_SERVER_CONNECTIONS
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers/verify")
async def verify_tool_servers_config(
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
):
"""
Verify the connection to the tool server.
"""
try:
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
url = f"{form_data.url}/{form_data.path}"
return await get_tool_server_data(token, url)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to connect to the tool server: {str(e)}",
)
############################
# CodeInterpreterConfig
############################

View File

@@ -18,6 +18,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -30,7 +32,17 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(user=Depends(get_verified_user)):
async def get_tools(request: Request, user=Depends(get_verified_user)):
if not request.app.state.TOOL_SERVERS:
# If the tool servers are not set, we need to set them
# This is done only once when the server starts
# This is done to avoid loading the tool servers every time
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
if user.role == "admin":
tools = Tools.get_tools()
else: