fix: Non streaming response is working

This commit is contained in:
smonux 2024-10-10 22:20:17 +02:00
parent 6777210d1f
commit c19c7b56cf

View File

@ -437,7 +437,6 @@ async def handle_streaming_response(request: Request, response: Response,
tools: dict, tools: dict,
data_items: list, data_items: list,
call_next) -> StreamingResponse: call_next) -> StreamingResponse:
log.debug(f"smonux 21 {response.headers}")
"""content_type = response.headers["Content-Type"] """content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type is_openai = "text/event-stream" in content_type
@ -460,10 +459,11 @@ async def handle_streaming_response(request: Request, response: Response,
yield data yield data
log.debug(f"smonux 24 {full_response}") log.debug(f"smonux 24 {full_response}")
full_response_dict = json.loads(full_response[full_response.find("{"): full_response.rfind("}") + 1]) full_response_dict = json.loads(full_response)
if full_response_dict["choices"][0]["finish_reason"] != "tool_calls": if full_response_dict["choices"][0]["finish_reason"] != "tool_calls":
break break
body["messages"].append(response_dict["choices"][0]["message"])
for tool_call in full_response_dict["choices"][0].get("tool_calls", []): for tool_call in full_response_dict["choices"][0].get("tool_calls", []):
tool_function_name = tool_call["function"]["name"] tool_function_name = tool_call["function"]["name"]
tool_function_params = json.loads(tool_call["function"]["arguments"]) tool_function_params = json.loads(tool_call["function"]["arguments"])
@ -477,9 +477,9 @@ async def handle_streaming_response(request: Request, response: Response,
# Append the tool output to the messages # Append the tool output to the messages
body["messages"].append({ body["messages"].append({
"role": "tool", "role": "tool",
"tool_call_id" : tool_call["id"],
"name": tool_function_name, "name": tool_function_name,
"content": tool_output, "content": tool_output
"tool_call_id" : tool_call["id"]
}) })
update_body_request(request, body) update_body_request(request, body)
response = await call_next(request) response = await call_next(request)
@ -490,8 +490,7 @@ async def handle_streaming_response(request: Request, response: Response,
headers=dict(response.headers), headers=dict(response.headers),
) )
async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> Response: async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse:
# It only should be one response # It only should be one response
async for data in response.body_iterator: async for data in response.body_iterator:
content = data content = data
@ -499,36 +498,30 @@ async def handle_nonstreaming_response(request: Request, response: Response, too
body = json.loads(request._body) body = json.loads(request._body)
while response_dict["choices"][0]["finish_reason"] == "tool_calls": while response_dict["choices"][0]["finish_reason"] == "tool_calls":
for tool_call in response_dict["choices"][0]["message"].get("tool_calls", []): body["messages"].append(response_dict["choices"][0]["message"])
log.debug(f"smonux 12 {tool_call}") tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
for tool_call in tool_calls:
tool_function_name = tool_call["function"]["name"] tool_function_name = tool_call["function"]["name"]
tool_function_params = json.loads(tool_call["function"]["arguments"]) tool_function_params = json.loads(tool_call["function"]["arguments"])
log.debug(f"smonux 13 {tool_function_params=}")
try: try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params) tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
log.debug(f"smonux 14 {tool_output=}")
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
# Append the tool output to the messages # Append the tool output to the messages
body["messages"].append({ body["messages"].append({
"role": "tool", "role": "tool",
"tool_call_id" : tool_call["id"],
"name": tool_function_name, "name": tool_function_name,
"content": tool_output "content": tool_output
}) })
# Make another request to the model with the updated context # Make another request to the model with the updated context
update_body_request(request, body) update_body_request(request, body)
response = await generate_chat_completions(form_data = body, user = user ) response_dict = await generate_chat_completions(form_data = body, user = user )
# response = await call_next(request)
async for data in response.body_iterator:
content = data
response_dict = json.loads(content)
import time; time.sleep(0.5)
return response return JSONResponse(content = response_dict)
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
@ -742,7 +735,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
} }
body["metadata"] = metadata body["metadata"] = metadata
log.debug("smonux 00")
body, tools = get_tools_body(body, user, extra_params) body, tools = get_tools_body(body, user, extra_params)
try: try: