feat: use hash for remote server tools

This commit is contained in:
signorettif
2025-06-18 14:59:25 +02:00
parent b5f4c85bb1
commit fc6bded532
2 changed files with 82 additions and 75 deletions

View File

@@ -1,48 +1,26 @@
import asyncio
import copy
import inspect
import logging
import re
import inspect
from functools import partial, update_wrapper
from typing import Any, Awaitable, Callable, Dict, List, Optional, get_type_hints
import aiohttp
import asyncio
import yaml
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing import (
Any,
Awaitable,
Callable,
get_type_hints,
get_args,
get_origin,
Dict,
List,
Tuple,
Union,
Optional,
Type,
)
from functools import update_wrapper, partial
from fastapi import Request
from pydantic import BaseModel, Field, create_model
from langchain_core.utils.function_calling import (
convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
)
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
SRC_LOG_LEVELS,
)
from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id
from open_webui.env import (
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
)
import copy
from pydantic import BaseModel, Field, create_model
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -76,16 +54,18 @@ def get_tools(
tool = Tools.get_tool_by_id(tool_id)
if tool is None:
if tool_id.startswith("server:"):
server_idx = int(tool_id.split(":")[1])
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
)
server_hash = tool_id.split(":", 1)[1] # Get hash part after "server:"
tool_server_data = None
for server in request.app.state.TOOL_SERVERS:
if server["idx"] == server_idx:
if server["hash"] == server_hash:
tool_server_data = server
break
assert tool_server_data is not None
# Get connection config directly from server data
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
)
specs = tool_server_data.get("specs", [])
for spec in specs:
@@ -102,7 +82,8 @@ def get_tools(
def make_tool_function(function_name, token, tool_server_data):
async def tool_function(**kwargs):
print(
f"Executing tool function {function_name} with params: {kwargs}"
f"Executing tool function {function_name} with params:"
f" {kwargs}"
)
return await execute_tool_server(
token=token,
@@ -190,8 +171,9 @@ def get_tools(
"spec": spec,
# Misc info
"metadata": {
"file_handler": hasattr(module, "file_handler")
and module.file_handler,
"file_handler": (
hasattr(module, "file_handler") and module.file_handler
),
"citation": hasattr(module, "citation") and module.citation,
},
}
@@ -440,6 +422,26 @@ def convert_openapi_to_tool_payload(openapi_spec):
return tool_payload
def compute_server_hash(server_url: str, server_path: str) -> str:
    """
    Compute a stable identifier for a tool server from its URL and path.

    The URL and path are joined with exactly one "/" between them (trailing
    slashes on the URL and leading slashes on the path are dropped), so
    "http://host/" + "/openapi.json" and "http://host" + "openapi.json"
    yield the same identifier.

    The identifier is the first 8 hexadecimal digits of the SHA-256 digest of
    that combined string. A cryptographic digest is used instead of Python's
    built-in hash() because hash() of a str is salted per interpreter process
    (see PYTHONHASHSEED): its value changes across restarts and between
    worker processes, which would make the resulting "server:<hash>" IDs
    impossible to look up again.

    Args:
        server_url: The base URL of the server (None is treated as "")
        server_path: The path to the OpenAPI spec (None is treated as "")

    Returns:
        An 8-character lowercase hexadecimal string to use as the server
        identifier
    """
    # Imported locally so this function does not depend on the module's
    # top-level import block.
    import hashlib

    base = (server_url or "").rstrip("/")
    path = (server_path or "").lstrip("/")
    combined = f"{base}/{path}"

    digest = hashlib.sha256(combined.encode("utf-8")).hexdigest()
    # Keep the 8-hex-digit width of the previous identifier format.
    return digest[:8]
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
headers = {
"Accept": "application/json",
@@ -486,7 +488,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
async def get_tool_servers_data(
servers: List[Dict[str, Any]], session_token: Optional[str] = None
) -> List[Dict[str, Any]]:
# Prepare list of enabled servers along with their original index
# Prepare list of enabled servers along with their hash-based identifier
server_entries = []
for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"):
@@ -504,6 +506,8 @@ async def get_tool_servers_data(
info = server.get("info", {})
server_hash = compute_server_hash(server.get("url", ""), openapi_path)
auth_type = server.get("auth_type", "bearer")
token = None
@@ -511,19 +515,21 @@ async def get_tool_servers_data(
token = server.get("key", "")
elif auth_type == "session":
token = session_token
server_entries.append((idx, server, full_url, info, token))
server_entries.append((idx, server_hash, server, full_url, info, token))
# Create async tasks to fetch data
tasks = [
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries
]
# Execute tasks concurrently
responses = await asyncio.gather(*tasks, return_exceptions=True)
# Build final results with index and server metadata
# Build final results with ID, hash, url, openapi, info, and specs
results = []
for (idx, server, url, info, _), response in zip(server_entries, responses):
for (idx, server_hash, server, url, info, _), response in zip(
server_entries, responses
):
if isinstance(response, Exception):
log.error(f"Failed to connect to {url} OpenAPI tool server")
continue
@@ -540,6 +546,7 @@ async def get_tool_servers_data(
results.append(
{
"idx": idx,
"hash": server_hash, # Hash-based ID
"url": server.get("url"),
"openapi": openapi_data,
"info": response.get("info"),