diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0a33c68b1..50898800a 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -881,6 +881,17 @@ except Exception: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" +#################################### +# TOOL_SERVERS +#################################### + + +TOOL_SERVER_CONNECTIONS = PersistentConfig( + "TOOL_SERVER_CONNECTIONS", + "tool_server.connections", + [], +) + #################################### # WEBUI #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 383523174..c9ca059c2 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -105,6 +105,8 @@ from open_webui.config import ( OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, + # Tool Server Configs + TOOL_SERVER_CONNECTIONS, # Code Execution ENABLE_CODE_EXECUTION, CODE_EXECUTION_ENGINE, @@ -356,6 +358,7 @@ from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( get_license_data, + get_http_authorization_cred, decode_token, get_admin_user, get_verified_user, @@ -478,6 +481,15 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS app.state.OPENAI_MODELS = {} +######################################## +# +# TOOL SERVERS +# +######################################## + +app.state.config.TOOL_SERVER_CONNECTIONS = TOOL_SERVER_CONNECTIONS +app.state.TOOL_SERVERS = [] + ######################################## # # DIRECT CONNECTIONS @@ -864,6 +876,10 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) + request.state.token = get_http_authorization_cred( + request.headers.get("Authorization") + ) + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 2a4c651f2..e891b7aa8 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,5 +1,5 @@ -from fastapi import APIRouter, Depends, Request -from pydantic import BaseModel +from fastapi import APIRouter, Depends, Request, HTTPException +from pydantic import BaseModel, ConfigDict from typing import Optional @@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel +from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data + router = APIRouter() @@ -66,6 +68,73 @@ async def set_direct_connections_config( } +############################ +# ToolServers Config +############################ + + +class ToolServerConnection(BaseModel): + url: str + path: str + auth_type: Optional[str] + key: Optional[str] + config: Optional[dict] + + model_config = ConfigDict(extra="allow") + + +class ToolServersConfigForm(BaseModel): + TOOL_SERVER_CONNECTIONS: list[ToolServerConnection] + + +@router.get("/tool_servers", response_model=ToolServersConfigForm) +async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)): + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + +@router.post("/tool_servers", response_model=ToolServersConfigForm) +async def set_tool_servers_config( + request: Request, + form_data: ToolServersConfigForm, + user=Depends(get_admin_user), +): + request.app.state.config.TOOL_SERVER_CONNECTIONS = form_data.TOOL_SERVER_CONNECTIONS + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + +@router.post("/tool_servers/verify") +async def verify_tool_servers_config( + request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user) +): + """ + Verify the connection to the tool server. + """ + try: + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + + url = f"{form_data.url}/{form_data.path}" + return await get_tool_server_data(token, url) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to connect to the tool server: {str(e)}", + ) + + ############################ # CodeInterpreterConfig ############################ diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 211264cde..ace517daa 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -18,6 +18,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.tools import get_tool_servers_data + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -30,7 +32,17 @@ router = APIRouter() @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(user=Depends(get_verified_user)): +async def get_tools(request: Request, user=Depends(get_verified_user)): + + if not request.app.state.TOOL_SERVERS: + # If the tool servers are not set, we need to set them + # This is done only once when the server starts + # This is done to avoid loading the tool servers every time + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + if user.role == "admin": tools = Tools.get_tools() else: diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 6af99f164..118ac049e 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -143,12 +143,14 @@ def create_api_key(): return f"sk-{key}" -def get_http_authorization_cred(auth_header: str): +def get_http_authorization_cred(auth_header: Optional[str]): + if not auth_header: + return None try: scheme, credentials = auth_header.split(" ") return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) except Exception: - raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) + return None def get_current_user( diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index bd2a731e6..425a95081 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -2,9 +2,10 @@ import inspect import logging import re import inspect -import uuid +import aiohttp +import asyncio -from typing import Any, Awaitable, Callable, get_type_hints +from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union from functools import update_wrapper, partial @@ -217,3 +218,260 @@ def get_tools_specs(tool_class: object) -> list[dict]: function_list = get_callable_attributes(tool_class) models = map(function_to_pydantic_model, function_list) return [convert_to_openai_function(tool) for tool in models] + + +import copy + + +def resolve_schema(schema, components): + """ + Recursively resolves a JSON schema using OpenAPI components. + """ + if not schema: + return {} + + if "$ref" in schema: + ref_path = schema["$ref"] + ref_parts = ref_path.strip("#/").split("/") + resolved = components + for part in ref_parts[1:]: # Skip the initial 'components' + resolved = resolved.get(part, {}) + return resolve_schema(resolved, components) + + resolved_schema = copy.deepcopy(schema) + + # Recursively resolve inner schemas + if "properties" in resolved_schema: + for prop, prop_schema in resolved_schema["properties"].items(): + resolved_schema["properties"][prop] = resolve_schema( + prop_schema, components + ) + + if "items" in resolved_schema: + resolved_schema["items"] = resolve_schema(resolved_schema["items"], components) + + return resolved_schema + + +def convert_openapi_to_tool_payload(openapi_spec): + """ + Converts an OpenAPI specification into a custom tool payload structure. + + Args: + openapi_spec (dict): The OpenAPI specification as a Python dict. + + Returns: + list: A list of tool payloads. + """ + tool_payload = [] + + for path, methods in openapi_spec.get("paths", {}).items(): + for method, operation in methods.items(): + tool = { + "type": "function", + "name": operation.get("operationId"), + "description": operation.get("summary", "No description available."), + "parameters": {"type": "object", "properties": {}, "required": []}, + } + + # Extract path and query parameters + for param in operation.get("parameters", []): + param_name = param["name"] + param_schema = param.get("schema", {}) + tool["parameters"]["properties"][param_name] = { + "type": param_schema.get("type"), + "description": param_schema.get("description", ""), + } + if param.get("required"): + tool["parameters"]["required"].append(param_name) + + # Extract and resolve requestBody if available + request_body = operation.get("requestBody") + if request_body: + content = request_body.get("content", {}) + json_schema = content.get("application/json", {}).get("schema") + if json_schema: + resolved_schema = resolve_schema( + json_schema, openapi_spec.get("components", {}) + ) + + if resolved_schema.get("properties"): + tool["parameters"]["properties"].update( + resolved_schema["properties"] + ) + if "required" in resolved_schema: + tool["parameters"]["required"] = list( + set( + tool["parameters"]["required"] + + resolved_schema["required"] + ) + ) + elif resolved_schema.get("type") == "array": + tool["parameters"] = resolved_schema # special case for array + + tool_payload.append(tool) + + return tool_payload + + +async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if token: + headers["Authorization"] = f"Bearer {token}" + + error = None + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status != 200: + error_body = await response.json() + raise Exception(error_body) + res = await response.json() + except Exception as err: + print("Error:", err) + if isinstance(err, dict) and "detail" in err: + error = err["detail"] + else: + error = str(err) + raise Exception(error) + + data = { + "openapi": res, + "info": res.get("info", {}), + "specs": convert_openapi_to_tool_payload(res), + } + + print("Fetched data:", data) + return data + + +async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + enabled_servers = [ + server for server in servers if server.get("config", {}).get("enable") + ] + + urls = [ + ( + server, + f"{server.get('url')}/{server.get('path', 'openapi.json')}", + server.get("key", ""), + ) + for server in enabled_servers + ] + + tasks = [get_tool_server_data(token, url) for _, url, token in urls] + + results: List[Dict[str, Any]] = [] + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for (server, _, _), response in zip(urls, responses): + if isinstance(response, Exception): + url_path = server.get("path", "openapi.json") + full_url = f"{server.get('url')}/{url_path}" + print(f"Failed to connect to {full_url} OpenAPI tool server") + else: + results.append( + { + "url": server.get("url"), + "openapi": response["openapi"], + "info": response["info"], + "specs": response["specs"], + } + ) + + return results + + +async def execute_tool_server( + token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any] +) -> Any: + error = None + try: + openapi = server_data.get("openapi", {}) + paths = openapi.get("paths", {}) + + matching_route = None + for route_path, methods in paths.items(): + for http_method, operation in methods.items(): + if isinstance(operation, dict) and operation.get("operationId") == name: + matching_route = (route_path, methods) + break + if matching_route: + break + + if not matching_route: + raise Exception(f"No matching route found for operationId: {name}") + + route_path, methods = matching_route + + method_entry = None + for http_method, operation in methods.items(): + if operation.get("operationId") == name: + method_entry = (http_method.lower(), operation) + break + + if not method_entry: + raise Exception(f"No matching method found for operationId: {name}") + + http_method, operation = method_entry + + path_params = {} + query_params = {} + body_params = {} + + for param in operation.get("parameters", []): + param_name = param["name"] + param_in = param["in"] + if param_name in params: + if param_in == "path": + path_params[param_name] = params[param_name] + elif param_in == "query": + query_params[param_name] = params[param_name] + + final_url = f"{url}{route_path}" + for key, value in path_params.items(): + final_url = final_url.replace(f"{{{key}}}", str(value)) + + if query_params: + query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) + final_url = f"{final_url}?{query_string}" + + if operation.get("requestBody", {}).get("content"): + if params: + body_params = params + else: + raise Exception( + f"Request body expected for operation '{name}' but none found." + ) + + headers = {"Content-Type": "application/json"} + + if token: + headers["Authorization"] = f"Bearer {token}" + + async with aiohttp.ClientSession() as session: + request_method = getattr(session, http_method.lower()) + + if http_method in ["post", "put", "patch"]: + async with request_method( + final_url, json=body_params, headers=headers + ) as response: + if response.status >= 400: + text = await response.text() + raise Exception(f"HTTP error {response.status}: {text}") + return await response.json() + else: + async with request_method(final_url, headers=headers) as response: + if response.status >= 400: + text = await response.text() + raise Exception(f"HTTP error {response.status}: {text}") + return await response.json() + + except Exception as err: + error = str(err) + print("API Request Error:", error) + return {"error": error} diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index f7f02c740..5872303f6 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -115,6 +115,93 @@ export const setDirectConnectionsConfig = async (token: string, config: object) return res; }; +export const getToolServerConnections = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setToolServerConnections = async (token: string, connections: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connections + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const verifyToolServerConnection = async (token: string, connection: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers/verify`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connection + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index b6d4d10d2..cdd6887b2 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -306,11 +306,11 @@ export const getToolServersData = async (i18n, servers: object[]) => { .map(async (server) => { const data = await getToolServerData( server?.key, - server?.url + (server?.path ?? '/openapi.json') + server?.url + '/' + (server?.path ?? 'openapi.json') ).catch((err) => { toast.error( i18n.t(`Failed to connect to {{URL}} OpenAPI tool server`, { - URL: server?.url + (server?.path ?? '/openapi.json') + URL: server?.url + '/' + (server?.path ?? 'openapi.json') }) ); return null; diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index 500032ba1..4e8747da9 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -15,6 +15,8 @@ import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; import Tags from './common/Tags.svelte'; + import { getToolServerData } from '$lib/apis'; + import { verifyToolServerConnection } from '$lib/apis/configs'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -22,10 +24,12 @@ export let show = false; export let edit = false; + export let direct = false; + export let connection = null; let url = ''; - let path = '/openapi.json'; + let path = 'openapi.json'; let auth_type = 'bearer'; let key = ''; @@ -34,6 +38,49 @@ let loading = false; + const verifyHandler = async () => { + if (url === '') { + toast.error($i18n.t('Please enter a valid URL')); + return; + } + + if (path === '') { + toast.error($i18n.t('Please enter a valid path')); + return; + } + + if (direct) { + const res = await getToolServerData( + auth_type === 'bearer' ? key : localStorage.token, + `${url}/${path}` + ).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } else { + const res = await verifyToolServerConnection(localStorage.token, { + url, + path, + auth_type, + key, + config: { + enable: enable + } + }).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } + }; + const submitHandler = async () => { loading = true; @@ -56,7 +103,7 @@ show = false; url = ''; - path = '/openapi.json'; + path = 'openapi.json'; key = ''; auth_type = 'bearer'; @@ -66,7 +113,7 @@ const init = () => { if (connection) { url = connection.url; - path = connection?.path ?? '/openapi.json'; + path = connection?.path ?? 'openapi.json'; auth_type = connection?.auth_type ?? 'bearer'; key = connection?.key ?? ''; @@ -125,20 +172,53 @@
-
{$i18n.t('URL')}
+
+
{$i18n.t('URL')}
+
-
+
+ + + + + + + +
-
+
+
/
- -
- - - -
- {$i18n.t(`WebUI will make requests to "{{url}}{{path}}"`, { - url: url, - path: path + {$i18n.t(`WebUI will make requests to "{{url}}"`, { + url: `${url}/${path}` })}
diff --git a/src/lib/components/admin/Settings/Tools.svelte b/src/lib/components/admin/Settings/Tools.svelte index ac0566f22..7e1c6b03e 100644 --- a/src/lib/components/admin/Settings/Tools.svelte +++ b/src/lib/components/admin/Settings/Tools.svelte @@ -1,317 +1,66 @@ - + - - -
+ { + updateHandler(); + }} +>
- {#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && directConnectionsConfig !== null} -
-
-
-
{$i18n.t('OpenAI API')}
+ {#if servers !== null} +
+
+
{$i18n.t('General')}
-
-
- { - updateOpenAIHandler(); - }} - /> -
-
-
- - {#if ENABLE_OPENAI_API} -
- -
-
-
{$i18n.t('Manage OpenAI API Connections')}
- - - - -
- -
- {#each OPENAI_API_BASE_URLS as url, idx} - { - updateOpenAIHandler(); - }} - onDelete={() => { - OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( - (url, urlIdx) => idx !== urlIdx - ); - OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); - - let newConfig = {}; - OPENAI_API_BASE_URLS.forEach((url, newIdx) => { - newConfig[newIdx] = OPENAI_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1]; - }); - OPENAI_API_CONFIGS = newConfig; - updateOpenAIHandler(); - }} - /> - {/each} -
-
- {/if} -
-
- -
- -
-
-
{$i18n.t('Ollama API')}
- -
- { - updateOllamaHandler(); - }} - /> -
-
- - {#if ENABLE_OLLAMA_API}
-
-
-
{$i18n.t('Manage Ollama API Connections')}
+
+ +
+
{$i18n.t('Manage Tool Servers')}
-
-
- {#each OLLAMA_BASE_URLS as url, idx} - { - updateOllamaHandler(); - }} - onDelete={() => { - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); +
+ {#each servers as server, idx} + { + updateHandler(); + }} + onDelete={() => { + servers = servers.filter((_, i) => i !== idx); + updateHandler(); + }} + /> + {/each} +
- let newConfig = {}; - OLLAMA_BASE_URLS.forEach((url, newIdx) => { - newConfig[newIdx] = OLLAMA_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1]; - }); - OLLAMA_API_CONFIGS = newConfig; - }} - /> - {/each} +
+
+ {$i18n.t('Connect to your own OpenAPI compatible external tool servers.')}
- -
- {$i18n.t('Trouble accessing Ollama?')} - - {$i18n.t('Click here for help.')} - -
- {/if} -
-
+
{:else} diff --git a/src/lib/components/chat/Settings/Tools.svelte b/src/lib/components/chat/Settings/Tools.svelte index bb207ce7f..86a9abbd2 100644 --- a/src/lib/components/chat/Settings/Tools.svelte +++ b/src/lib/components/chat/Settings/Tools.svelte @@ -39,7 +39,7 @@ }); - + { updateHandler(); }} diff --git a/src/lib/components/chat/Settings/Tools/Connection.svelte b/src/lib/components/chat/Settings/Tools/Connection.svelte index 27dd1758c..4f9f7f1a2 100644 --- a/src/lib/components/chat/Settings/Tools/Connection.svelte +++ b/src/lib/components/chat/Settings/Tools/Connection.svelte @@ -12,6 +12,7 @@ export let onSubmit = () => {}; export let connection = null; + export let direct = false; let showConfigModal = false; let showDeleteConfirmDialog = false; @@ -42,9 +43,8 @@