diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index cc742ca42..46ab075e1 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -351,7 +351,7 @@ async def get_content_from_response(response) -> Optional[str]: return content def get_tools_body( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, extra_params: dict, models ) -> tuple[dict, dict]: metadata = body.get("metadata", {}) @@ -360,14 +360,20 @@ def get_tools_body( if not tool_ids: return body, {} - task_model_id = get_task_model_id(body["model"]) + task_model_id = get_task_model_id( + body["model"], + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + tools = get_tools( webui_app, tool_ids, user, { **extra_params, - "__model__": app.state.MODELS[task_model_id], + "__model__": models[task_model_id], "__messages__": body["messages"], "__files__": metadata.get("files", []), }, @@ -861,7 +867,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): } body["metadata"] = metadata - body, tools = get_tools_body(body, user, extra_params) + body, tools = get_tools_body(body, user, extra_params, models) if model["owned_by"] == "ollama" and \ body["metadata"]["native_tool_call"] and \