From 25e06c849cddbbb63bdbb0bccf579c824dddca65 Mon Sep 17 00:00:00 2001 From: smonux Date: Sun, 13 Oct 2024 21:49:03 +0200 Subject: [PATCH] fix: function calling for haiku --- backend/open_webui/main.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e24afb8b1..ffb7b962f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -476,7 +476,9 @@ async def handle_streaming_response(request: Request, response: Response, if peek == b'data: [DONE]\n' and len(citations) > 0 : yield wrap_item(json.dumps({ "citations" : citations})) - if peek_json is None or not 'tool_calls' in peek_json['choices'][0]['delta']: + if peek_json is None or \ + not 'choices' in peek_json or \ + not 'tool_calls' in peek_json['choices'][0]['delta']: yield peek continue @@ -491,8 +493,9 @@ async def handle_streaming_response(request: Request, response: Response, delta = extract_json(data) if delta is None or \ + not 'choices' in delta or \ not 'tool_calls' in delta['choices'][0]['delta'] or \ - delta['choices'][0]['finish_reason'] is not None: + delta['choices'][0].get('finish_reason', None) is not None: continue i = delta['choices'][0]['delta']['tool_calls'][0]['index'] @@ -511,7 +514,8 @@ async def handle_streaming_response(request: Request, response: Response, for tool_call in tool_calls: tool_function_name = tool_call["function"]["name"] - tool_function_params = json.loads(tool_call["function"]["arguments"]) + tool_function_params = extract_json(tool_call["function"]["arguments"]) + tool_function_params = {} if tool_function_params is None else tool_function_params try: tool_output = await tools[tool_function_name]["callable"](**tool_function_params) @@ -548,6 +552,7 @@ async def handle_streaming_response(request: Request, response: Response, except StopAsyncIteration: pass except Exception as e: + import pdb; pdb.set_trace() log.exception(f"Error: {e}") return StreamingResponse( @@ -573,7 +578,9 @@ async def handle_nonstreaming_response(request: Request, response: Response, tool_calls = response_dict["choices"][0]["message"].get("tool_calls", []) for tool_call in tool_calls: tool_function_name = tool_call["function"]["name"] - tool_function_params = json.loads(tool_call["function"]["arguments"]) + tool_function_params = extract_json(tool_call["function"]["arguments"]) + + tool_function_params = {} if tool_function_params is None else tool_function_params try: tool_output = await tools[tool_function_name]["callable"](**tool_function_params)