fix: function calling for haiku

This commit is contained in:
smonux 2024-10-13 21:49:03 +02:00
parent 8c920c99c4
commit 25e06c849c

View File

@ -476,7 +476,9 @@ async def handle_streaming_response(request: Request, response: Response,
if peek == b'data: [DONE]\n' and len(citations) > 0 : if peek == b'data: [DONE]\n' and len(citations) > 0 :
yield wrap_item(json.dumps({ "citations" : citations})) yield wrap_item(json.dumps({ "citations" : citations}))
if peek_json is None or not 'tool_calls' in peek_json['choices'][0]['delta']: if peek_json is None or \
not 'choices' in peek_json or \
not 'tool_calls' in peek_json['choices'][0]['delta']:
yield peek yield peek
continue continue
@ -491,8 +493,9 @@ async def handle_streaming_response(request: Request, response: Response,
delta = extract_json(data) delta = extract_json(data)
if delta is None or \ if delta is None or \
not 'choices' in delta or \
not 'tool_calls' in delta['choices'][0]['delta'] or \ not 'tool_calls' in delta['choices'][0]['delta'] or \
delta['choices'][0]['finish_reason'] is not None: delta['choices'][0].get('finish_reason', None) is not None:
continue continue
i = delta['choices'][0]['delta']['tool_calls'][0]['index'] i = delta['choices'][0]['delta']['tool_calls'][0]['index']
@ -511,7 +514,8 @@ async def handle_streaming_response(request: Request, response: Response,
for tool_call in tool_calls: for tool_call in tool_calls:
tool_function_name = tool_call["function"]["name"] tool_function_name = tool_call["function"]["name"]
tool_function_params = json.loads(tool_call["function"]["arguments"]) tool_function_params = extract_json(tool_call["function"]["arguments"])
tool_function_params = {} if tool_function_params is None else tool_function_params
try: try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params) tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
@ -548,6 +552,7 @@ async def handle_streaming_response(request: Request, response: Response,
except StopAsyncIteration: except StopAsyncIteration:
pass pass
except Exception as e: except Exception as e:
import pdb; pdb.set_trace()
log.exception(f"Error: {e}") log.exception(f"Error: {e}")
return StreamingResponse( return StreamingResponse(
@ -573,7 +578,9 @@ async def handle_nonstreaming_response(request: Request, response: Response,
tool_calls = response_dict["choices"][0]["message"].get("tool_calls", []) tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
for tool_call in tool_calls: for tool_call in tool_calls:
tool_function_name = tool_call["function"]["name"] tool_function_name = tool_call["function"]["name"]
tool_function_params = json.loads(tool_call["function"]["arguments"]) tool_function_params = extract_json(tool_call["function"]["arguments"])
tool_function_params = {} if tool_function_params is None else tool_function_params
try: try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params) tool_output = await tools[tool_function_name]["callable"](**tool_function_params)