From e570a98bf76288c09d51be3d684b26ed455d4287 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 5 Apr 2025 05:31:22 -0600 Subject: [PATCH] refac: substandard codebase overhauled --- backend/open_webui/utils/tools.py | 87 ++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index c824ceec6..ffaa59bfd 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -23,21 +23,23 @@ import copy log = logging.getLogger(__name__) -def apply_extra_params_to_tool_function( +def get_async_tool_function_and_apply_extra_params( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): update_wrapper(partial_func, function) return partial_func + else: + # Make it a coroutine function + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) - async def new_function(*args, **kwargs): - return partial_func(*args, **kwargs) - - update_wrapper(new_function, function) - return new_function + update_wrapper(new_function, function) + return new_function def get_tools( @@ -48,22 +50,49 @@ def get_tools( for tool_id in tool_ids: tool = Tools.get_tool_by_id(tool_id) if tool is None: - if tool_id.startswith("server:"): server_idx = int(tool_id.split(":")[1]) + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx] + ) tool_server_data = request.app.state.TOOL_SERVERS[server_idx] + specs = tool_server_data.get("specs", []) - tool_dict = { - "spec": spec, - "callable": callable, - "tool_id": tool_id, - # Misc info - "metadata": { - "file_handler": hasattr(module, "file_handler") - and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - }, - } + for spec in specs: + function_name = spec["name"] + + auth_type = tool_server_connection.get("auth_type", "bearer") + token = None + + if auth_type == "bearer": + token = tool_server_connection.get("key", "") + elif auth_type == "session": + token = request.state.token.credentials + + callable = get_async_tool_function_and_apply_extra_params( + execute_tool_server, + { + "token": token, + "url": tool_server_data["url"], + "name": function_name, + "server_data": tool_server_data, + }, + ) + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" + ) + log.warning(f"Discarding {tool_id}.{function_name}") + else: + tools_dict[function_name] = tool_dict else: continue else: @@ -73,10 +102,11 @@ def get_tools( request.app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id + + # Set valves for the tool if hasattr(module, "valves") and hasattr(module, "Valves"): valves = Tools.get_tool_valves_by_id(tool_id) or {} module.valves = module.Valves(**valves) - if hasattr(module, "UserValves"): extra_params["__user__"]["valves"] = module.UserValves( # type: ignore **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) @@ -89,21 +119,21 @@ def get_tools( if val["type"] == "str": val["type"] = "string" - # Remove internal parameters + # Remove internal reserved parameters (e.g. __id__, __user__) spec["parameters"]["properties"] = { key: val for key, val in spec["parameters"]["properties"].items() if not key.startswith("__") } - function_name = spec["name"] - # convert to function that takes only model params and inserts custom params - original_func = getattr(module, function_name) - callable = apply_extra_params_to_tool_function( - original_func, extra_params + function_name = spec["name"] + tool_function = getattr(module, function_name) + callable = get_async_tool_function_and_apply_extra_params( + tool_function, extra_params ) + # TODO: Support Pydantic models as parameters if callable.__doc__ and callable.__doc__.strip() != "": s = re.split(":(param|return)", callable.__doc__, 1) spec["description"] = s[0] @@ -111,9 +141,9 @@ def get_tools( spec["description"] = function_name tool_dict = { - "spec": spec, - "callable": callable, "tool_id": tool_id, + "callable": callable, + "spec": spec, # Misc info "metadata": { "file_handler": hasattr(module, "file_handler") @@ -127,8 +157,7 @@ def get_tools( log.warning( f"Tool {function_name} already exists in another tools!" ) - log.warning(f"Collision between {tool} and {tool_id}.") - log.warning(f"Discarding {tool}.{function_name}") + log.warning(f"Discarding {tool_id}.{function_name}") else: tools_dict[function_name] = tool_dict