fix: now it handles OpenRouter providers correctly

This commit is contained in:
smonux 2024-10-13 08:37:01 +02:00
parent 0a902fc691
commit faf9915a79

View File

@ -433,8 +433,11 @@ def update_body_request(request: Request,
return None return None
def extract_json(binary:bytes)->dict: def extract_json(binary:bytes)->dict:
s = binary.decode("utf-8") try:
return json.loads(s[ s.find("{") : s.rfind("}") + 1 ]) s = binary.decode("utf-8")
return json.loads(s[ s.find("{") : s.rfind("}") + 1 ])
except Exception as e:
return None
def fill_with_delta(fcall_dict:dict, delta:dict) -> None: def fill_with_delta(fcall_dict:dict, delta:dict) -> None:
if not 'delta' in delta['choices'][0]: if not 'delta' in delta['choices'][0]:
@ -469,12 +472,9 @@ async def handle_streaming_response(request: Request, response: Response,
generator = original_generator generator = original_generator
try: try:
while True: while True:
peek = await generator.__anext__() peek = await anext(generator)
if peek in (b'\n', b'data: [DONE]\n'):
yield peek
continue
peek_json = extract_json(peek) peek_json = extract_json(peek)
if not 'tool_calls' in peek_json['choices'][0]['delta']: if peek_json is None or not 'tool_calls' in peek_json['choices'][0]['delta']:
yield peek yield peek
continue continue
@ -486,10 +486,11 @@ async def handle_streaming_response(request: Request, response: Response,
fill_with_delta(tool_calls[current_index], peek_json) fill_with_delta(tool_calls[current_index], peek_json)
async for data in generator: async for data in generator:
if data in (b'\n', b'data: [DONE]\n'):
continue
delta = extract_json(data) delta = extract_json(data)
if delta['choices'][0]['finish_reason'] is not None:
if delta is None or \
not 'tool_calls' in delta['choices'][0]['delta'] or \
delta['choices'][0]['finish_reason'] is not None:
continue continue
i = delta['choices'][0]['delta']['tool_calls'][0]['index'] i = delta['choices'][0]['delta']['tool_calls'][0]['index']
@ -515,6 +516,19 @@ async def handle_streaming_response(request: Request, response: Response,
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
# With several tools potentially called together, maybe this doesn't make sense
if tools[tool_function_name]["file_handler"] and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
if tools[tool_function_name]["citation"]:
citation = {
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
}
yield wrap_item(json.dumps(citation))
# Append the tool output to the messages # Append the tool output to the messages
body["messages"].append({ body["messages"].append({
"role": "tool", "role": "tool",
@ -527,6 +541,7 @@ async def handle_streaming_response(request: Request, response: Response,
log.debug("calling the model again with tool output included") log.debug("calling the model again with tool output included")
update_body_request(request, body) update_body_request(request, body)
response = await generate_chat_completions(form_data = body, user = user ) response = await generate_chat_completions(form_data = body, user = user )
# body_iterator here does not have __anext__() so it has to be done this way
generator = response.body_iterator.__aiter__() generator = response.body_iterator.__aiter__()
except StopAsyncIteration as sie: except StopAsyncIteration as sie: