This commit is contained in:
Francesco Signoretti 2025-06-21 17:19:41 +05:30 committed by GitHub
commit be1f7ee7ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 75 deletions

View File

@ -1,29 +1,25 @@
import logging import logging
from pathlib import Path
from typing import Optional
import time
import re import re
import aiohttp import time
from pydantic import BaseModel, HttpUrl from typing import Optional
import aiohttp
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.tools import ( from open_webui.models.tools import (
ToolForm, ToolForm,
ToolModel, ToolModel,
ToolResponse, ToolResponse,
ToolUserResponse,
Tools, Tools,
ToolUserResponse,
) )
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.utils.tools import get_tool_servers_data from open_webui.utils.tools import get_tool_servers_data, get_tool_specs
from pydantic import BaseModel, HttpUrl
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
@ -53,21 +49,25 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
tools.append( tools.append(
ToolUserResponse( ToolUserResponse(
**{ **{
"id": f"server:{server['idx']}", "id": f"server:{server['hash']}",
"user_id": f"server:{server['idx']}", "user_id": f"server:{server['hash']}",
"name": server.get("openapi", {}) "name": (
.get("info", {}) server.get("openapi", {})
.get("title", "Tool Server"),
"meta": {
"description": server.get("openapi", {})
.get("info", {}) .get("info", {})
.get("description", ""), .get("title", "Tool Server")
),
"meta": {
"description": (
server.get("openapi", {})
.get("info", {})
.get("description", "")
),
}, },
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ "access_control": (
server["idx"] request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
] .get("config", {})
.get("config", {}) .get("access_control", None)
.get("access_control", None), ),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
} }

View File

@ -1,48 +1,26 @@
import asyncio
import copy
import inspect import inspect
import logging import logging
import re import re
import inspect from functools import partial, update_wrapper
from typing import Any, Awaitable, Callable, Dict, List, Optional, get_type_hints
import aiohttp import aiohttp
import asyncio
import yaml import yaml
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing import (
Any,
Awaitable,
Callable,
get_type_hints,
get_args,
get_origin,
Dict,
List,
Tuple,
Union,
Optional,
Type,
)
from functools import update_wrapper, partial
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, Field, create_model
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
) )
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
SRC_LOG_LEVELS,
)
from open_webui.models.tools import Tools from open_webui.models.tools import Tools
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id from open_webui.utils.plugin import load_tool_module_by_id
from open_webui.env import ( from pydantic import BaseModel, Field, create_model
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
)
import copy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -76,16 +54,18 @@ def get_tools(
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
if tool is None: if tool is None:
if tool_id.startswith("server:"): if tool_id.startswith("server:"):
server_idx = int(tool_id.split(":")[1]) server_hash = tool_id.split(":", 1)[1] # Get hash part after "server:"
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
)
tool_server_data = None tool_server_data = None
for server in request.app.state.TOOL_SERVERS: for server in request.app.state.TOOL_SERVERS:
if server["idx"] == server_idx: if server["hash"] == server_hash:
tool_server_data = server tool_server_data = server
break break
assert tool_server_data is not None assert tool_server_data is not None
# Get connection config directly from server data
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
)
specs = tool_server_data.get("specs", []) specs = tool_server_data.get("specs", [])
for spec in specs: for spec in specs:
@ -102,7 +82,8 @@ def get_tools(
def make_tool_function(function_name, token, tool_server_data): def make_tool_function(function_name, token, tool_server_data):
async def tool_function(**kwargs): async def tool_function(**kwargs):
print( print(
f"Executing tool function {function_name} with params: {kwargs}" f"Executing tool function {function_name} with params:"
f" {kwargs}"
) )
return await execute_tool_server( return await execute_tool_server(
token=token, token=token,
@ -190,8 +171,9 @@ def get_tools(
"spec": spec, "spec": spec,
# Misc info # Misc info
"metadata": { "metadata": {
"file_handler": hasattr(module, "file_handler") "file_handler": (
and module.file_handler, hasattr(module, "file_handler") and module.file_handler
),
"citation": hasattr(module, "citation") and module.citation, "citation": hasattr(module, "citation") and module.citation,
}, },
} }
@ -440,6 +422,26 @@ def convert_openapi_to_tool_payload(openapi_spec):
return tool_payload return tool_payload
def compute_server_hash(server_url: str, server_path: str) -> str:
"""
Compute a hash for a tool server based on its URL and path.
Creates a consistent string representation by combining the normalized URL and path,
then uses Python's built-in hash() function for performance. The result is converted
to a positive integer and formatted as an 8-digit hexadecimal string for consistent
length and use as a server identifier.
Args:
server_url: The base URL of the server
server_path: The path to the OpenAPI spec
Returns:
A hash string to use as the server identifier
"""
combined = f"{server_url.rstrip('/')}/{server_path.lstrip('/')}"
return f"{hash(combined) & 0x7FFFFFFF:08x}"
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
headers = { headers = {
"Accept": "application/json", "Accept": "application/json",
@ -486,7 +488,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
async def get_tool_servers_data( async def get_tool_servers_data(
servers: List[Dict[str, Any]], session_token: Optional[str] = None servers: List[Dict[str, Any]], session_token: Optional[str] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
# Prepare list of enabled servers along with their original index # Prepare list of enabled servers along with their hash-based identifier
server_entries = [] server_entries = []
for idx, server in enumerate(servers): for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"): if server.get("config", {}).get("enable"):
@ -504,6 +506,8 @@ async def get_tool_servers_data(
info = server.get("info", {}) info = server.get("info", {})
server_hash = compute_server_hash(server.get("url", ""), openapi_path)
auth_type = server.get("auth_type", "bearer") auth_type = server.get("auth_type", "bearer")
token = None token = None
@ -511,19 +515,21 @@ async def get_tool_servers_data(
token = server.get("key", "") token = server.get("key", "")
elif auth_type == "session": elif auth_type == "session":
token = session_token token = session_token
server_entries.append((idx, server, full_url, info, token)) server_entries.append((idx, server_hash, server, full_url, info, token))
# Create async tasks to fetch data # Create async tasks to fetch data
tasks = [ tasks = [
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries
] ]
# Execute tasks concurrently # Execute tasks concurrently
responses = await asyncio.gather(*tasks, return_exceptions=True) responses = await asyncio.gather(*tasks, return_exceptions=True)
# Build final results with index and server metadata # Build final results with ID, hash, url, openapi, info, and specs
results = [] results = []
for (idx, server, url, info, _), response in zip(server_entries, responses): for (idx, server_hash, server, url, info, _), response in zip(
server_entries, responses
):
if isinstance(response, Exception): if isinstance(response, Exception):
log.error(f"Failed to connect to {url} OpenAPI tool server") log.error(f"Failed to connect to {url} OpenAPI tool server")
continue continue
@ -540,6 +546,7 @@ async def get_tool_servers_data(
results.append( results.append(
{ {
"idx": idx, "idx": idx,
"hash": server_hash, # Hash-based ID
"url": server.get("url"), "url": server.get("url"),
"openapi": openapi_data, "openapi": openapi_data,
"info": response.get("info"), "info": response.get("info"),