diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 3fbeb6c84..7fc4e3983 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1052,6 +1052,7 @@ async def chat_completion( "message_id": form_data.pop("id", None), "session_id": form_data.pop("session_id", None), "tool_ids": form_data.get("tool_ids", None), + "tool_servers": form_data.pop("tool_servers", None), "files": form_data.get("files", None), "features": form_data.get("features", None), "variables": form_data.get("variables", None), diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index f6d81214e..77b01bdfc 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -213,8 +213,9 @@ async def chat_completion_tools_handler( "type": "execute:tool", "data": { "id": str(uuid4()), - "tool": tool, + "name": tool_function_name, "params": tool_function_params, + "tool": tool, "server": tool.get("server", {}), "session_id": metadata.get("session_id", None), }, @@ -224,17 +225,30 @@ async def chat_completion_tools_handler( except Exception as e: tool_output = str(e) + if isinstance(tool_output, dict): + tool_output = json.dumps(tool_output, indent=4) + if isinstance(tool_output, str): - if tools[tool_function_name]["citation"]: + tool_id = tools[tool_function_name].get("toolkit_id", "") + if tools[tool_function_name].get("citation", False): + sources.append( { "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + "name": ( + f"TOOL:" + f"{tool_id}/{tool_function_name}" + if tool_id + else f"{tool_function_name}" + ), }, "document": [tool_output], "metadata": [ { - "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + "source": ( + f"TOOL:" + f"{tool_id}/{tool_function_name}" + if tool_id + else f"{tool_function_name}" + ) } ], } @@ -246,13 +260,17 @@ async def chat_completion_tools_handler( "document": [tool_output], "metadata": [ { - "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + "source": ( + f"TOOL:" + f"{tool_id}/{tool_function_name}" + if tool_id + else f"{tool_function_name}" + ) } ], } ) - if tools[tool_function_name]["file_handler"]: + if tools[tool_function_name].get("file_handler", False): skip_files = True # check if "tool_calls" in result @@ -788,7 +806,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Server side tools tool_ids = metadata.get("tool_ids", None) # Client side tools - tool_servers = form_data.get("tool_servers", None) + tool_servers = metadata.get("tool_servers", None) log.debug(f"{tool_ids=}") log.debug(f"{tool_servers=}") @@ -1824,8 +1842,9 @@ async def process_chat_response( "type": "execute:tool", "data": { "id": str(uuid4()), - "tool": tool, + "name": tool_name, "params": tool_function_params, + "tool": tool, "server": tool.get("server", {}), "session_id": metadata.get( "session_id", None diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 674f24267..2e6e19e6a 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; +import { convertOpenApiToToolPayload } from '$lib/utils'; import { getOpenAIModelsDirect } from './openai'; export const getModels = async ( @@ -256,6 +257,138 @@ export const stopTask = async (token: string, id: string) => { return res; }; +export const getToolServerData = async (token: string, url: string) => { + let error = null; + + const res = await fetch(`${url}/openapi.json`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + return null; + }); + + if (error) { + throw error; + } + + const data = { + openapi: res, + info: res.info, + specs: convertOpenApiToToolPayload(res) + }; + + console.log(data); + return data; +}; + +export const getToolServersData = async (servers: object[]) => { + return await Promise.all( + servers + .filter(async (server) => server?.config?.enable) + .map(async (server) => { + const data = await getToolServerData(server?.key, server?.url).catch((err) => { + console.error(err); + return null; + }); + + if (data) { + const { openapi, info, specs } = data; + return { + url: server?.url, + openapi: openapi, + info: info, + specs: specs + }; + } + }) + ); +}; + +export const executeToolServer = async ( + token: string, + url: string, + name: string, + params: object, + serverData: { openapi: any; info: any; specs: any } +) => { + let error = null; + + try { + // Find the matching operationId in the OpenAPI specification + const matchingRoute = Object.entries(serverData.openapi.paths).find(([path, methods]) => { + return Object.entries(methods).some( + ([method, operation]: any) => operation.operationId === name + ); + }); + + if (!matchingRoute) { + throw new Error(`No matching route found for operationId: ${name}`); + } + + const [route, methods] = matchingRoute; + const methodEntry = Object.entries(methods).find( + ([method, operation]: any) => operation.operationId === name + ); + + if (!methodEntry) { + throw new Error(`No matching method found for operationId: ${name}`); + } + + const [httpMethod, operation]: [string, any] = methodEntry; + + // Replace path parameters in the URL + let finalUrl = `${url}${route}`; + if (operation.parameters) { + Object.entries(params).forEach(([key, value]) => { + finalUrl = finalUrl.replace(`{${key}}`, encodeURIComponent(value as string)); + }); + } + + // Headers and request options + const headers = { + ...(token && { authorization: `Bearer ${token}` }), + 'Content-Type': 'application/json' + }; + + let requestOptions: RequestInit = { + method: httpMethod.toUpperCase(), + headers + }; + + // Handle request body for POST, PUT, PATCH + if (['post', 'put', 'patch'].includes(httpMethod.toLowerCase()) && operation.requestBody) { + requestOptions.body = JSON.stringify(params); + } + + // Execute the request + const res = await fetch(finalUrl, requestOptions); + if (!res.ok) { + throw new Error(`HTTP error! Status: ${res.status}`); + } + + return await res.json(); + } catch (err: any) { + error = err.message; + console.error('API Request Error:', error); + return { error }; + } +}; + export const getTaskConfig = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index fe733d616..a6337ef8e 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -35,7 +35,8 @@ showOverview, chatTitle, showArtifacts, - tools + tools, + toolServers } from '$lib/stores'; import { convertMessagesToHistory, @@ -120,8 +121,6 @@ let webSearchEnabled = false; let codeInterpreterEnabled = false; - let toolServers = []; - let chat = null; let tags = []; @@ -194,8 +193,6 @@ setToolIds(); } - $: toolServers = ($settings?.toolServers ?? []).filter((server) => server?.config?.enable); - const setToolIds = async () => { if (!$tools) { tools.set(await getTools(localStorage.token)); @@ -1570,6 +1567,7 @@ files: (files?.length ?? 0) > 0 ? files : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + tool_servers: $toolServers, features: { image_generation: @@ -2038,7 +2036,7 @@ bind:codeInterpreterEnabled bind:webSearchEnabled bind:atSelectedModel - {toolServers} + toolServers={$toolServers} transparentBackground={$settings?.backgroundImageUrl ?? false} {stopResponse} {createMessagePair} @@ -2092,7 +2090,7 @@ bind:webSearchEnabled bind:atSelectedModel transparentBackground={$settings?.backgroundImageUrl ?? false} - {toolServers} + toolServers={$toolServers} {stopResponse} {createMessagePair} on:upload={async (e) => { diff --git a/src/lib/components/chat/Settings/Tools.svelte b/src/lib/components/chat/Settings/Tools.svelte index 740e4712f..a900b5a46 100644 --- a/src/lib/components/chat/Settings/Tools.svelte +++ b/src/lib/components/chat/Settings/Tools.svelte @@ -1,12 +1,12 @@