From 6777210d1f8eb3dd40c2c9c0bca8157a22a35a64 Mon Sep 17 00:00:00 2001
From: smonux
Date: Wed, 9 Oct 2024 06:58:44 +0200
Subject: [PATCH] incremental fixes nonstreaming

---
 backend/open_webui/main.py | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 4919df643..942aff66b 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -490,17 +490,25 @@ async def handle_streaming_response(request: Request, response: Response,
         headers=dict(response.headers),
     )
 
-async def handle_nonstreaming_response(request: Request, response: Response, tools: dict) -> Response:
-    response_dict = json.loads(response.content)
+async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> Response:
+
+    # There should be only one response
+    async for data in response.body_iterator:
+        content = data
+    response_dict = json.loads(content)
     body = json.loads(request._body)
 
     while response_dict["choices"][0]["finish_reason"] == "tool_calls":
-        for tool_call in response_dict["choices"][0]["tool_calls"]:
+        for tool_call in response_dict["choices"][0]["message"].get("tool_calls", []):
+            log.debug(f"smonux 12 {tool_call}")
             tool_function_name = tool_call["function"]["name"]
             tool_function_params = json.loads(tool_call["function"]["arguments"])
+            log.debug(f"smonux 13 {tool_function_params=}")
+
             try:
                 tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
+                log.debug(f"smonux 14 {tool_output=}")
             except Exception as e:
                 tool_output = str(e)
 
@@ -513,11 +521,14 @@ async def handle_nonstreaming_response(request: Request, response: Response, too
 
         # Make another request to the model with the updated context
        update_body_request(request, body)
-        response = await call_next(request)
-        response_dict = json.loads(response)
+        response = await generate_chat_completions(form_data=body, user=user)
+#        response = await call_next(request)
+        async for data in response.body_iterator:
+            content = data
+        response_dict = json.loads(content)
 
+    import time; time.sleep(0.5)
     return response
-
 
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
@@ -731,6 +742,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             }
         body["metadata"] = metadata
 
+        log.debug("smonux 00")
         body, tools = get_tools_body(body, user, extra_params)
 
         try:
@@ -776,11 +788,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if len(citations) > 0:
             data_items.append({"citations": citations})
 
+        log.debug("smonux 10")
         update_body_request(request, body)
         first_response = await call_next(request)
 
-        #if body.get("stream", False) is False:
-        #    return await handle_nonstreaming_response(request, first_response, tools)
+        if body.get("stream", False) is False:
+            return await handle_nonstreaming_response(request, first_response, tools, user)
 
         log.debug("smonux 20")
         return await handle_streaming_response(request, first_response, tools, data_items, call_next)
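
For readers following the change: below is a minimal, self-contained sketch of the non-streaming tool-call loop this patch is building toward, with the response buffering and debug statements stripped out. The `complete` coroutine, the shape of the `tools` dict, and the tool-result message format are assumptions modeled on the OpenAI chat-completions conventions the patch appears to follow; they are not confirmed Open WebUI APIs.

```python
# Hypothetical sketch: the names `run_tool_call_loop` and `complete` are
# illustrative and do not exist in open_webui/main.py.
import json


async def run_tool_call_loop(body: dict, tools: dict, complete) -> dict:
    """Re-query the model until it stops requesting tool calls.

    `complete` is assumed to be an async callable that takes the request body
    and returns a parsed, non-streaming chat-completions response dict.
    `tools` is assumed to map a tool name to {"callable": <async function>}.
    """
    response = await complete(body)
    while response["choices"][0]["finish_reason"] == "tool_calls":
        message = response["choices"][0]["message"]
        # Keep the assistant turn that requested the tools in the context.
        body["messages"].append(message)
        for tool_call in message.get("tool_calls", []):
            name = tool_call["function"]["name"]
            params = json.loads(tool_call["function"]["arguments"])
            try:
                output = await tools[name]["callable"](**params)
            except Exception as e:
                output = str(e)
            # Feed the tool result back as a "tool"-role message.
            body["messages"].append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.get("id"),
                    "name": name,
                    "content": str(output),
                }
            )
        # Ask the model again with the updated context.
        response = await complete(body)
    return response
```

In the patch itself the follow-up request goes through `generate_chat_completions(form_data=body, user=user)` and the JSON has to be drained from `response.body_iterator`, so the loop above is only a simplification of the control flow, not a drop-in replacement.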