incremental fixes nonstreaming

This commit is contained in:
smonux 2024-10-09 06:58:44 +02:00
parent eb1366b4bc
commit 6777210d1f

View File

@ -490,17 +490,25 @@ async def handle_streaming_response(request: Request, response: Response,
headers=dict(response.headers), headers=dict(response.headers),
) )
async def handle_nonstreaming_response(request: Request, response: Response, tools: dict) -> Response: async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> Response:
response_dict = json.loads(response.content)
# It only should be one response
async for data in response.body_iterator:
content = data
response_dict = json.loads(content)
body = json.loads(request._body) body = json.loads(request._body)
while response_dict["choices"][0]["finish_reason"] == "tool_calls": while response_dict["choices"][0]["finish_reason"] == "tool_calls":
for tool_call in response_dict["choices"][0]["tool_calls"]: for tool_call in response_dict["choices"][0]["message"].get("tool_calls", []):
log.debug(f"smonux 12 {tool_call}")
tool_function_name = tool_call["function"]["name"] tool_function_name = tool_call["function"]["name"]
tool_function_params = json.loads(tool_call["function"]["arguments"]) tool_function_params = json.loads(tool_call["function"]["arguments"])
log.debug(f"smonux 13 {tool_function_params=}")
try: try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params) tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
log.debug(f"smonux 14 {tool_output=}")
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
@ -513,11 +521,14 @@ async def handle_nonstreaming_response(request: Request, response: Response, too
# Make another request to the model with the updated context # Make another request to the model with the updated context
update_body_request(request, body) update_body_request(request, body)
response = await call_next(request) response = await generate_chat_completions(form_data = body, user = user )
response_dict = json.loads(response) # response = await call_next(request)
async for data in response.body_iterator:
content = data
response_dict = json.loads(content)
import time; time.sleep(0.5)
return response return response
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
@ -731,6 +742,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
} }
body["metadata"] = metadata body["metadata"] = metadata
log.debug("smonux 00")
body, tools = get_tools_body(body, user, extra_params) body, tools = get_tools_body(body, user, extra_params)
try: try:
@ -776,11 +788,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0: if len(citations) > 0:
data_items.append({"citations": citations}) data_items.append({"citations": citations})
log.debug("smonux 10")
update_body_request(request, body) update_body_request(request, body)
first_response = await call_next(request) first_response = await call_next(request)
#if body.get("stream", False) is False: if body.get("stream", False) is False:
# return await handle_nonstreaming_response(request, first_response, tools) return await handle_nonstreaming_response(request, first_response, tools, user)
log.debug("smonux 20") log.debug("smonux 20")
return await handle_streaming_response(request, first_response, tools, data_items, call_next) return await handle_streaming_response(request, first_response, tools, data_items, call_next)