From 3724e888c638ace8facc36066df11638f608c65e Mon Sep 17 00:00:00 2001 From: Samuel Date: Sun, 3 Nov 2024 08:50:33 +0000 Subject: [PATCH] fixes for cohere models --- backend/open_webui/main.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9f2d81940..96585a86c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -579,10 +579,9 @@ async def handle_nonstreaming_response(request: Request, response: Response, is_ollama = True is_openai = not is_ollama - log.debug(f"smonux 1: { response_dict=} ") while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \ - (is_openai and "tool_calls" in response_dict["choices"][0].get("message",{}) ): + (is_openai and "tool_calls" in response_dict.get("choices", [{}])[0].get("message",{}) ): if is_ollama: message = response_dict.get("message", {}) tool_calls = message.get("tool_calls", []) @@ -593,6 +592,10 @@ async def handle_nonstreaming_response(request: Request, response: Response, tool_calls = response_dict["choices"][0]["message"].get("tool_calls", []) for tool_call in tool_calls: + # fix for cohere + if 'index' in tool_call: + del tool_call['index'] + tool_function_name = tool_call["function"]["name"] if not tool_call["function"]["arguments"]: tool_function_params = {} @@ -602,13 +605,12 @@ async def handle_nonstreaming_response(request: Request, response: Response, if is_ollama: tool_function_params = tool_call["function"]["arguments"] - try: tool_output = await tools[tool_function_name]["callable"](**tool_function_params) except Exception as e: tool_output = str(e) - if tools[tool_function_name]["citation"]: + if tools.get(tool_function_name, {}).get("citation", False): citations.append( { "source": { "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"