From ec9be0d20d0ff3a1210a1da6da65f41d8c269e53 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 4 Feb 2025 19:14:59 -0800 Subject: [PATCH] feat: native tool calling frontend --- backend/open_webui/main.py | 8 +++++ backend/open_webui/utils/middleware.py | 17 ++++++----- .../Settings/Advanced/AdvancedParams.svelte | 30 +++++++++++++++++++ .../components/chat/Settings/General.svelte | 1 + 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 0b453d2c0..265cb10c5 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -861,6 +861,7 @@ async def chat_completion( if model_id not in request.app.state.MODELS: raise Exception("Model not found") model = request.app.state.MODELS[model_id] + model_info = Models.get_model_by_id(model_id) # Check if user has access to the model if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user": @@ -878,6 +879,13 @@ async def chat_completion( "files": form_data.get("files", None), "features": form_data.get("features", None), "variables": form_data.get("variables", None), + "model": model_info, + **( + {"function_calling": "native"} + if form_data.get("params", {}).get("function_calling") == "native" + or model_info.params.model_dump().get("function_calling") == "native" + else {} + ), } form_data["metadata"] = metadata diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 8cdd82196..81989a654 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -702,6 +702,7 @@ def apply_params_to_form_data(form_data, model): async def process_chat_payload(request, form_data, metadata, user, model): + form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") @@ -808,13 +809,15 @@ async def process_chat_payload(request, form_data, metadata, user, model): } form_data["metadata"] = metadata - try: - form_data, flags = await chat_completion_tools_handler( - request, form_data, user, models, extra_params - ) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) + if not form_data["metadata"].get("function_calling") == "native": + # If the function calling is not native, then call the tools function calling handler + try: + form_data, flags = await chat_completion_tools_handler( + request, form_data, user, models, extra_params + ) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) try: form_data, flags = await chat_completion_files_handler(request, form_data, user) diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index a56ccab33..0ab38adf4 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -12,6 +12,7 @@ export let params = { // Advanced stream_response: null, // Set stream responses for this model individually + function_calling: null, seed: null, stop: null, temperature: null, @@ -81,6 +82,35 @@ +
+ +
+
+ {$i18n.t('Function Calling')} +
+ +
+
+
+