From 5a7efad59c21f69ff389ef496050042f92afc12d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 26 Mar 2025 00:40:24 -0700 Subject: [PATCH] refac: tools --- backend/open_webui/utils/middleware.py | 77 +++++++++++++++++++++----- backend/open_webui/utils/tools.py | 3 + src/routes/+layout.svelte | 19 +++++++ 3 files changed, 86 insertions(+), 13 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e04bb17cf..89fa10fbb 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) async def chat_completion_tools_handler( - request: Request, body: dict, user: UserModel, models, tools + request: Request, body: dict, extra_params: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: async def get_content_from_response(response) -> Optional[str]: content = None @@ -135,6 +135,9 @@ async def chat_completion_tools_handler( "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } + event_caller = extra_params["__event_call__"] + metadata = extra_params["__metadata__"] + task_model_id = get_task_model_id( body["model"], request.app.state.config.TASK_MODEL, @@ -189,17 +192,33 @@ async def chat_completion_tools_handler( tool_function_params = tool_call.get("parameters", {}) try: - spec = tools[tool_function_name].get("spec", {}) + tool = tools[tool_function_name] + + spec = tool.get("spec", {}) allowed_params = ( spec.get("parameters", {}).get("properties", {}).keys() ) - tool_function = tools[tool_function_name]["callable"] + tool_function = tool["callable"] tool_function_params = { k: v for k, v in tool_function_params.items() if k in allowed_params } - tool_output = await tool_function(**tool_function_params) + + if tool.get("direct", False): + tool_output = await tool_function(**tool_function_params) + else: + tool_output = await event_caller( + { + "type": "execute:tool", + "data": { + "id": str(uuid4()), + "tool": tool, + "params": tool_function_params, + "session_id": metadata.get("session_id", None), + }, + } + ) except Exception as e: tool_output = str(e) @@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model): } form_data["metadata"] = metadata + # Server side tools tool_ids = metadata.get("tool_ids", None) + # Client side tools + tool_specs = form_data.get("tool_specs", None) + log.debug(f"{tool_ids=}") + log.debug(f"{tool_specs=}") + + tools_dict = {} if tool_ids: - # If tool_ids field is present, then get the tools - tools = get_tools( + tools_dict = get_tools( request, tool_ids, user, @@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model): "__files__": metadata.get("files", []), }, ) - log.info(f"{tools=}") + log.info(f"{tools_dict=}") + if tool_specs: + for tool in tool_specs: + callable = tool.pop("callable", None) + tools_dict[tool["name"]] = { + "direct": True, + "callable": callable, + "spec": tool, + } + + if tools_dict: if metadata.get("function_calling") == "native": # If the function calling is native, then call the tools function calling handler - metadata["tools"] = tools + metadata["tools"] = tools_dict form_data["tools"] = [ {"type": "function", "function": tool.get("spec", {})} - for tool in tools.values() + for tool in tools_dict.values() ] else: # If the function calling is not native, then call the tools function calling handler try: form_data, flags = await chat_completion_tools_handler( - request, form_data, user, models, tools + request, form_data, extra_params, user, models, tools_dict ) sources.extend(flags.get("sources", [])) @@ -1774,9 +1809,25 @@ async def process_chat_response( for k, v in tool_function_params.items() if k in allowed_params } - tool_result = await tool_function( - **tool_function_params - ) + + if tool.get("direct", False): + tool_result = await tool_function( + **tool_function_params + ) + else: + tool_result = await event_caller( + { + "type": "execute:tool", + "data": { + "id": str(uuid4()), + "tool": tool, + "params": tool_function_params, + "session_id": metadata.get( + "session_id", None + ), + }, + } + ) except Exception as e: tool_result = str(e) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index c44c30402..53ecf4d0e 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,6 +1,9 @@ import inspect import logging import re +import inspect +import uuid + from typing import Any, Awaitable, Callable, get_type_hints from functools import update_wrapper, partial diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index b1567fd9e..79030e731 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -203,6 +203,22 @@ }; }; + const executeTool = async (data, cb) => { + console.log(data); + // TODO: MCP (SSE) support + // TODO: API Server support + + if (cb) { + cb( + JSON.parse( + JSON.stringify({ + result: null + }) + ) + ); + } + }; + const chatEventHandler = async (event, cb) => { const chat = $page.url.pathname.includes(`/c/${event.chat_id}`); @@ -256,6 +272,9 @@ if (type === 'execute:python') { console.log('execute:python', data); executePythonAsWorker(data.id, data.code, cb); + } else if (type === 'execute:tool') { + console.log('execute:tool', data); + executeTool(data, cb); } else if (type === 'request:chat:completion') { console.log(data, $socket.id); const { session_id, channel, form_data, model } = data;