diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index f726368eb..d32a7cce3 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,29 +1,25 @@ import logging -from pathlib import Path -from typing import Optional -import time import re -import aiohttp -from pydantic import BaseModel, HttpUrl +import time +from typing import Optional +import aiohttp +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import SRC_LOG_LEVELS from open_webui.models.tools import ( ToolForm, ToolModel, ToolResponse, - ToolUserResponse, Tools, + ToolUserResponse, ) -from open_webui.utils.plugin import load_tool_module_by_id, replace_imports -from open_webui.config import CACHE_DIR -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.tools import get_tool_specs -from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission -from open_webui.env import SRC_LOG_LEVELS - -from open_webui.utils.tools import get_tool_servers_data - +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.plugin import load_tool_module_by_id, replace_imports +from open_webui.utils.tools import get_tool_servers_data, get_tool_specs +from pydantic import BaseModel, HttpUrl log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -53,21 +49,25 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): tools.append( ToolUserResponse( **{ - "id": f"server:{server['idx']}", - "user_id": f"server:{server['idx']}", - "name": server.get("openapi", {}) - .get("info", {}) - .get("title", "Tool Server"), - "meta": { - "description": server.get("openapi", {}) + "id": f"server:{server['hash']}", + "user_id": f"server:{server['hash']}", + "name": ( + server.get("openapi", {}) .get("info", {}) - .get("description", ""), + .get("title", "Tool Server") + ), + "meta": { + "description": ( + server.get("openapi", {}) + .get("info", {}) + .get("description", "") + ), }, - "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ - server["idx"] - ] - .get("config", {}) - .get("access_control", None), + "access_control": ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]] + .get("config", {}) + .get("access_control", None) + ), "updated_at": int(time.time()), "created_at": int(time.time()), } diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index dda2635ec..879aaae01 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,48 +1,26 @@ +import asyncio +import copy import inspect import logging import re -import inspect +from functools import partial, update_wrapper +from typing import Any, Awaitable, Callable, Dict, List, Optional, get_type_hints + import aiohttp -import asyncio import yaml - -from pydantic import BaseModel -from pydantic.fields import FieldInfo -from typing import ( - Any, - Awaitable, - Callable, - get_type_hints, - get_args, - get_origin, - Dict, - List, - Tuple, - Union, - Optional, - Type, -) -from functools import update_wrapper, partial - - from fastapi import Request -from pydantic import BaseModel, Field, create_model - from langchain_core.utils.function_calling import ( convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, ) - - +from open_webui.env import ( + AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, + AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, + SRC_LOG_LEVELS, +) from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tool_module_by_id -from open_webui.env import ( - SRC_LOG_LEVELS, - AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, - AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, -) - -import copy +from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -76,16 +54,18 @@ def get_tools( tool = Tools.get_tool_by_id(tool_id) if tool is None: if tool_id.startswith("server:"): - server_idx = int(tool_id.split(":")[1]) - tool_server_connection = ( - request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx] - ) + server_hash = tool_id.split(":", 1)[1] # Get hash part after "server:" tool_server_data = None for server in request.app.state.TOOL_SERVERS: - if server["idx"] == server_idx: + if server["hash"] == server_hash: tool_server_data = server break assert tool_server_data is not None + + # Get connection config directly from server data + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]] + ) specs = tool_server_data.get("specs", []) for spec in specs: @@ -102,7 +82,8 @@ def get_tools( def make_tool_function(function_name, token, tool_server_data): async def tool_function(**kwargs): print( - f"Executing tool function {function_name} with params: {kwargs}" + f"Executing tool function {function_name} with params:" + f" {kwargs}" ) return await execute_tool_server( token=token, @@ -190,8 +171,9 @@ def get_tools( "spec": spec, # Misc info "metadata": { - "file_handler": hasattr(module, "file_handler") - and module.file_handler, + "file_handler": ( + hasattr(module, "file_handler") and module.file_handler + ), "citation": hasattr(module, "citation") and module.citation, }, } @@ -440,6 +422,26 @@ def convert_openapi_to_tool_payload(openapi_spec): return tool_payload +def compute_server_hash(server_url: str, server_path: str) -> str: + """ + Compute a hash for a tool server based on its URL and path. + + Creates a consistent string representation by combining the normalized URL and path, + then uses Python's built-in hash() function for performance. The result is converted + to a positive integer and formatted as an 8-digit hexadecimal string for consistent + length and use as a server identifier. + + Args: + server_url: The base URL of the server + server_path: The path to the OpenAPI spec + + Returns: + A hash string to use as the server identifier + """ + combined = f"{server_url.rstrip('/')}/{server_path.lstrip('/')}" + return f"{hash(combined) & 0x7FFFFFFF:08x}" + + async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: headers = { "Accept": "application/json", @@ -486,7 +488,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: async def get_tool_servers_data( servers: List[Dict[str, Any]], session_token: Optional[str] = None ) -> List[Dict[str, Any]]: - # Prepare list of enabled servers along with their original index + # Prepare list of enabled servers along with their hash-based identifier server_entries = [] for idx, server in enumerate(servers): if server.get("config", {}).get("enable"): @@ -504,6 +506,8 @@ async def get_tool_servers_data( info = server.get("info", {}) + server_hash = compute_server_hash(server.get("url", ""), openapi_path) + auth_type = server.get("auth_type", "bearer") token = None @@ -511,19 +515,21 @@ async def get_tool_servers_data( token = server.get("key", "") elif auth_type == "session": token = session_token - server_entries.append((idx, server, full_url, info, token)) + server_entries.append((idx, server_hash, server, full_url, info, token)) # Create async tasks to fetch data tasks = [ - get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries + get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries ] # Execute tasks concurrently responses = await asyncio.gather(*tasks, return_exceptions=True) - # Build final results with index and server metadata + # Build final results with ID, hash, url, openapi, info, and specs results = [] - for (idx, server, url, info, _), response in zip(server_entries, responses): + for (idx, server_hash, server, url, info, _), response in zip( + server_entries, responses + ): if isinstance(response, Exception): log.error(f"Failed to connect to {url} OpenAPI tool server") continue @@ -540,6 +546,7 @@ async def get_tool_servers_data( results.append( { "idx": idx, + "hash": server_hash, # Hash-based ID "url": server.get("url"), "openapi": openapi_data, "info": response.get("info"),