refac: external tools server support

This commit is contained in:
Timothy Jaeryang Baek 2025-03-27 02:27:56 -07:00
parent 69dee19568
commit d1bc2cfa2f
9 changed files with 264 additions and 30 deletions

View File

@ -1052,6 +1052,7 @@ async def chat_completion(
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
"tool_ids": form_data.get("tool_ids", None),
"tool_servers": form_data.pop("tool_servers", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),

View File

@ -213,8 +213,9 @@ async def chat_completion_tools_handler(
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"name": tool_function_name,
"params": tool_function_params,
"tool": tool,
"server": tool.get("server", {}),
"session_id": metadata.get("session_id", None),
},
@ -224,17 +225,30 @@ async def chat_completion_tools_handler(
except Exception as e:
tool_output = str(e)
if isinstance(tool_output, dict):
tool_output = json.dumps(tool_output, indent=4)
if isinstance(tool_output, str):
if tools[tool_function_name]["citation"]:
tool_id = tools[tool_function_name].get("toolkit_id", "")
if tools[tool_function_name].get("citation", False):
sources.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
"name": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
),
},
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
"source": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
)
}
],
}
@ -246,13 +260,17 @@ async def chat_completion_tools_handler(
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
"source": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
)
}
],
}
)
if tools[tool_function_name]["file_handler"]:
if tools[tool_function_name].get("file_handler", False):
skip_files = True
# check if "tool_calls" in result
@ -788,7 +806,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_servers = form_data.get("tool_servers", None)
tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_servers=}")
@ -1824,8 +1842,9 @@ async def process_chat_response(
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"name": tool_name,
"params": tool_function_params,
"tool": tool,
"server": tool.get("server", {}),
"session_id": metadata.get(
"session_id", None

View File

@ -1,4 +1,5 @@
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { convertOpenApiToToolPayload } from '$lib/utils';
import { getOpenAIModelsDirect } from './openai';
export const getModels = async (
@ -256,6 +257,138 @@ export const stopTask = async (token: string, id: string) => {
return res;
};
export const getToolServerData = async (token: string, url: string) => {
let error = null;
const res = await fetch(`${url}/openapi.json`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
return null;
});
if (error) {
throw error;
}
const data = {
openapi: res,
info: res.info,
specs: convertOpenApiToToolPayload(res)
};
console.log(data);
return data;
};
export const getToolServersData = async (servers: object[]) => {
return await Promise.all(
servers
.filter(async (server) => server?.config?.enable)
.map(async (server) => {
const data = await getToolServerData(server?.key, server?.url).catch((err) => {
console.error(err);
return null;
});
if (data) {
const { openapi, info, specs } = data;
return {
url: server?.url,
openapi: openapi,
info: info,
specs: specs
};
}
})
);
};
export const executeToolServer = async (
token: string,
url: string,
name: string,
params: object,
serverData: { openapi: any; info: any; specs: any }
) => {
let error = null;
try {
// Find the matching operationId in the OpenAPI specification
const matchingRoute = Object.entries(serverData.openapi.paths).find(([path, methods]) => {
return Object.entries(methods).some(
([method, operation]: any) => operation.operationId === name
);
});
if (!matchingRoute) {
throw new Error(`No matching route found for operationId: ${name}`);
}
const [route, methods] = matchingRoute;
const methodEntry = Object.entries(methods).find(
([method, operation]: any) => operation.operationId === name
);
if (!methodEntry) {
throw new Error(`No matching method found for operationId: ${name}`);
}
const [httpMethod, operation]: [string, any] = methodEntry;
// Replace path parameters in the URL
let finalUrl = `${url}${route}`;
if (operation.parameters) {
Object.entries(params).forEach(([key, value]) => {
finalUrl = finalUrl.replace(`{${key}}`, encodeURIComponent(value as string));
});
}
// Headers and request options
const headers = {
...(token && { authorization: `Bearer ${token}` }),
'Content-Type': 'application/json'
};
let requestOptions: RequestInit = {
method: httpMethod.toUpperCase(),
headers
};
// Handle request body for POST, PUT, PATCH
if (['post', 'put', 'patch'].includes(httpMethod.toLowerCase()) && operation.requestBody) {
requestOptions.body = JSON.stringify(params);
}
// Execute the request
const res = await fetch(finalUrl, requestOptions);
if (!res.ok) {
throw new Error(`HTTP error! Status: ${res.status}`);
}
return await res.json();
} catch (err: any) {
error = err.message;
console.error('API Request Error:', error);
return { error };
}
};
export const getTaskConfig = async (token: string = '') => {
let error = null;

View File

@ -35,7 +35,8 @@
showOverview,
chatTitle,
showArtifacts,
tools
tools,
toolServers
} from '$lib/stores';
import {
convertMessagesToHistory,
@ -120,8 +121,6 @@
let webSearchEnabled = false;
let codeInterpreterEnabled = false;
let toolServers = [];
let chat = null;
let tags = [];
@ -194,8 +193,6 @@
setToolIds();
}
$: toolServers = ($settings?.toolServers ?? []).filter((server) => server?.config?.enable);
const setToolIds = async () => {
if (!$tools) {
tools.set(await getTools(localStorage.token));
@ -1570,6 +1567,7 @@
files: (files?.length ?? 0) > 0 ? files : undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
tool_servers: $toolServers,
features: {
image_generation:
@ -2038,7 +2036,7 @@
bind:codeInterpreterEnabled
bind:webSearchEnabled
bind:atSelectedModel
{toolServers}
toolServers={$toolServers}
transparentBackground={$settings?.backgroundImageUrl ?? false}
{stopResponse}
{createMessagePair}
@ -2092,7 +2090,7 @@
bind:webSearchEnabled
bind:atSelectedModel
transparentBackground={$settings?.backgroundImageUrl ?? false}
{toolServers}
toolServers={$toolServers}
{stopResponse}
{createMessagePair}
on:upload={async (e) => {

View File

@ -1,12 +1,12 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
import { getModels as _getModels } from '$lib/apis';
import { getModels as _getModels, getToolServersData } from '$lib/apis';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import { models, settings, user } from '$lib/stores';
import { models, settings, toolServers, user } from '$lib/stores';
import Switch from '$lib/components/common/Switch.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
@ -30,6 +30,8 @@
await saveSettings({
toolServers: servers
});
toolServers.set(await getToolServersData($settings?.toolServers ?? []));
};
onMount(async () => {

View File

@ -58,6 +58,8 @@ export const knowledge: Writable<null | Document[]> = writable(null);
export const tools = writable(null);
export const functions = writable(null);
export const toolServers = writable([]);
export const banners: Writable<Banner[]> = writable([]);
export const settings: Writable<Settings> = writable({});

View File

@ -1070,3 +1070,59 @@ export const getLineCount = (text) => {
console.log(typeof text);
return text ? text.split('\n').length : 0;
};
export const convertOpenApiToToolPayload = (openApiSpec) => {
const toolPayload = [];
for (const [path, methods] of Object.entries(openApiSpec.paths)) {
for (const [method, operation] of Object.entries(methods)) {
const tool = {
type: 'function',
name: operation.operationId,
description: operation.summary || 'No description available.',
parameters: {
type: 'object',
properties: {},
required: []
}
};
// Extract path or query parameters
if (operation.parameters) {
operation.parameters.forEach((param) => {
tool.parameters.properties[param.name] = {
type: param.schema.type,
description: param.schema.description || ''
};
if (param.required) {
tool.parameters.required.push(param.name);
}
});
}
// Extract parameters from requestBody if applicable
if (operation.requestBody) {
const ref = operation.requestBody.content['application/json'].schema['$ref'];
if (ref) {
const schemaName = ref.split('/').pop();
const schemaDef = openApiSpec.components.schemas[schemaName];
if (schemaDef && schemaDef.properties) {
for (const [prop, details] of Object.entries(schemaDef.properties)) {
tool.parameters.properties[prop] = {
type: details.type,
description: details.description || ''
};
}
tool.parameters.required = schemaDef.required || [];
}
}
}
toolPayload.push(tool);
}
}
return toolPayload;
};

View File

@ -12,7 +12,7 @@
import { getKnowledgeBases } from '$lib/apis/knowledge';
import { getFunctions } from '$lib/apis/functions';
import { getModels, getVersionUpdates } from '$lib/apis';
import { getModels, getToolServersData, getVersionUpdates } from '$lib/apis';
import { getAllTags } from '$lib/apis/chats';
import { getPrompts } from '$lib/apis/prompts';
import { getTools } from '$lib/apis/tools';
@ -35,7 +35,8 @@
banners,
showSettings,
showChangelog,
temporaryChatEnabled
temporaryChatEnabled,
toolServers
} from '$lib/stores';
import Sidebar from '$lib/components/layout/Sidebar.svelte';
@ -43,6 +44,7 @@
import ChangelogModal from '$lib/components/ChangelogModal.svelte';
import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
import UpdateInfoToast from '$lib/components/layout/UpdateInfoToast.svelte';
import { get } from 'svelte/store';
const i18n = getContext('i18n');
@ -99,8 +101,10 @@
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
)
);
banners.set(await getBanners(localStorage.token));
tools.set(await getTools(localStorage.token));
toolServers.set(await getToolServersData($settings?.toolServers ?? []));
document.addEventListener('keydown', async function (event) {
const isCtrlPressed = event.ctrlKey || event.metaKey; // metaKey is for Cmd key on Mac

View File

@ -31,7 +31,7 @@
import { page } from '$app/stores';
import { Toaster, toast } from 'svelte-sonner';
import { getBackendConfig } from '$lib/apis';
import { executeToolServer, getBackendConfig } from '$lib/apis';
import { getSessionUser } from '$lib/apis/auths';
import '../tailwind.css';
@ -205,17 +205,36 @@
const executeTool = async (data, cb) => {
console.log(data);
// TODO: MCP (SSE) support
// TODO: API Server support
if (cb) {
cb(
JSON.parse(
JSON.stringify({
result: null
})
)
);
const toolServer = $settings?.toolServers?.find((server) => server.url === data.server?.url);
if (toolServer) {
const res = await executeToolServer(
toolServer.key,
toolServer.url,
data?.name,
data?.params,
toolServer
).catch((error) => {
console.error('executeToolServer', error);
return {
error: error
};
});
if (cb) {
cb(JSON.parse(JSON.stringify(res)));
}
} else {
if (cb) {
cb(
JSON.parse(
JSON.stringify({
error: 'Tool Server Not Found'
})
)
);
}
}
};