From 9747a0e1f1d3612dd42ac0c50597760bc733b19e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 5 Apr 2025 04:40:01 -0600 Subject: [PATCH] refac: tool servers --- backend/open_webui/routers/tools.py | 39 +++++++++- backend/open_webui/utils/tools.py | 73 ++++++++++--------- src/lib/components/AddServerModal.svelte | 21 +++++- .../chat/Settings/Tools/Connection.svelte | 2 +- 4 files changed, 95 insertions(+), 40 deletions(-) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index ace517daa..8a98b4e20 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,6 +1,7 @@ import logging from pathlib import Path from typing import Optional +import time from open_webui.models.tools import ( ToolForm, @@ -43,10 +44,40 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): request.app.state.config.TOOL_SERVER_CONNECTIONS ) - if user.role == "admin": - tools = Tools.get_tools() - else: - tools = Tools.get_tools_by_user_id(user.id, "read") + tools = Tools.get_tools() + for idx, server in enumerate(request.app.state.TOOL_SERVERS): + tools.append( + ToolUserResponse( + **{ + "id": f"server:{server['idx']}", + "user_id": f"server:{server['idx']}", + "name": server["openapi"] + .get("info", {}) + .get("title", "Tool Server"), + "meta": { + "description": server["openapi"] + .get("info", {}) + .get("description", ""), + }, + "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ + idx + ] + .get("config", {}) + .get("access_control", None), + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + ) + + if user.role != "admin": + tools = [ + tool + for tool in tools + if tool.user_id == user.id + or has_access(user.id, "read", tool.access_control) + ] + return tools diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 425a95081..84def7060 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -5,7 +5,7 @@ import inspect import aiohttp import asyncio -from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union +from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional from functools import update_wrapper, partial @@ -348,40 +348,47 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: return data -async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - enabled_servers = [ - server for server in servers if server.get("config", {}).get("enable") - ] - - urls = [ - ( - server, - f"{server.get('url')}/{server.get('path', 'openapi.json')}", - server.get("key", ""), - ) - for server in enabled_servers - ] - - tasks = [get_tool_server_data(token, url) for _, url, token in urls] - - results: List[Dict[str, Any]] = [] - - responses = await asyncio.gather(*tasks, return_exceptions=True) - - for (server, _, _), response in zip(urls, responses): - if isinstance(response, Exception): +async def get_tool_servers_data( + servers: List[Dict[str, Any]], session_token: Optional[str] = None +) -> List[Dict[str, Any]]: + # Prepare list of enabled servers along with their original index + server_entries = [] + for idx, server in enumerate(servers): + if server.get("config", {}).get("enable"): url_path = server.get("path", "openapi.json") full_url = f"{server.get('url')}/{url_path}" - print(f"Failed to connect to {full_url} OpenAPI tool server") - else: - results.append( - { - "url": server.get("url"), - "openapi": response["openapi"], - "info": response["info"], - "specs": response["specs"], - } - ) + + auth_type = server.get("auth_type", "bearer") + token = None + + if auth_type == "bearer": + token = server.get("key", "") + elif auth_type == "session": + token = session_token + server_entries.append((idx, server, full_url, token)) + + # Create async tasks to fetch data + tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries] + + # Execute tasks concurrently + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Build final results with index and server metadata + results = [] + for (idx, server, url, _), response in zip(server_entries, responses): + if isinstance(response, Exception): + print(f"Failed to connect to {url} OpenAPI tool server") + continue + + results.append( + { + "idx": idx, + "url": server.get("url"), + "openapi": response.get("openapi"), + "info": response.get("info"), + "specs": response.get("specs"), + } + ) return results diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index 4e8747da9..0ca1063d8 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -17,6 +17,7 @@ import Tags from './common/Tags.svelte'; import { getToolServerData } from '$lib/apis'; import { verifyToolServerConnection } from '$lib/apis/configs'; + import AccessControl from './workspace/common/AccessControl.svelte'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -34,6 +35,8 @@ let auth_type = 'bearer'; let key = ''; + let accessControl = null; + let enable = true; let loading = false; @@ -68,7 +71,8 @@ auth_type, key, config: { - enable: enable + enable: enable, + access_control: accessControl } }).catch((err) => { toast.error($i18n.t('Connection failed')); @@ -93,7 +97,8 @@ auth_type, key, config: { - enable: enable + enable: enable, + access_control: accessControl } }; @@ -108,6 +113,7 @@ auth_type = 'bearer'; enable = true; + accessControl = null; }; const init = () => { @@ -119,6 +125,7 @@ key = connection?.key ?? ''; enable = connection.config?.enable ?? true; + accessControl = connection.config?.access_control ?? null; } }; @@ -269,6 +276,16 @@ + + {#if !direct} +
+ +
+
+ +
+
+ {/if}
diff --git a/src/lib/components/chat/Settings/Tools/Connection.svelte b/src/lib/components/chat/Settings/Tools/Connection.svelte index 4f9f7f1a2..5ca5a4521 100644 --- a/src/lib/components/chat/Settings/Tools/Connection.svelte +++ b/src/lib/components/chat/Settings/Tools/Connection.svelte @@ -20,7 +20,7 @@ {