diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 7379d063d..4919df643 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -433,8 +433,68 @@ def update_body_request(request: Request,
     return None
 
+async def handle_streaming_response(request: Request, response: Response,
+                                    tools: dict,
+                                    data_items: list,
+                                    call_next) -> StreamingResponse:
+    body = json.loads(request._body)
+
+    content_type = response.headers["Content-Type"]
+    is_openai = "text/event-stream" in content_type
+    is_ollama = "application/x-ndjson" in content_type
+    if not is_openai and not is_ollama:
+        return response
+
+    def wrap_item(item):
+        return f"data: {item}\n\n" if is_openai else f"{item}\n"
+
+    async def stream_wrapper(original_generator, data_items):
+        for item in data_items:
+            yield wrap_item(json.dumps(item))
+
+        while True:
+            full_response = ""
+            async for data in original_generator:
+                full_response += data.decode("utf-8") if isinstance(data, bytes) else data
+                yield data
+
+            # Heuristic: extract the JSON payload from the accumulated body
+            full_response_dict = json.loads(
+                full_response[full_response.find("{"):full_response.rfind("}") + 1]
+            )
+            if full_response_dict["choices"][0]["finish_reason"] != "tool_calls":
+                break
+
+            for tool_call in full_response_dict["choices"][0].get("tool_calls", []):
+                tool_function_name = tool_call["function"]["name"]
+                tool_function_params = json.loads(tool_call["function"]["arguments"])
+
+                try:
+                    tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
+                except Exception as e:
+                    tool_output = str(e)
+
+                # Append the tool output to the messages
+                body["messages"].append({
+                    "role": "tool",
+                    "name": tool_function_name,
+                    "content": tool_output,
+                    "tool_call_id": tool_call["id"],
+                })
+
+            # Re-issue the request with the tool results and keep streaming
+            update_body_request(request, body)
+            response = await call_next(request)
+            original_generator = response.body_iterator
+
+    return StreamingResponse(
+        stream_wrapper(response.body_iterator, data_items),
+        headers=dict(response.headers),
+    )
+
+
 async def handle_nonstreaming_response(request: Request, response: Response,
                                        tools: dict) -> Response:
-    response_dict = json.loads(response)
+    response_dict = json.loads(response.content)
     body = json.loads(request._body)
 
     while response_dict["choices"][0]["finish_reason"] == "tool_calls":
@@ -721,30 +781,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         update_body_request(request, body)
         first_response = await call_next(request)
 
-        if not isinstance(first_response, StreamingResponse):
-            response = await handle_nonstreaming_response(request, first_response, tools)
-            return response
-
-        content_type = first_response.headers["Content-Type"]
-        is_openai = "text/event-stream" in content_type
-        is_ollama = "application/x-ndjson" in content_type
-        if not is_openai and not is_ollama:
-            return first_response
-
-        def wrap_item(item):
-            return f"data: {item}\n\n" if is_openai else f"{item}\n"
-
-        async def stream_wrapper(original_generator, data_items):
-            for item in data_items:
-                yield wrap_item(json.dumps(item))
-
-            async for data in original_generator:
-                yield data
-
-        return StreamingResponse(
-            stream_wrapper(first_response.body_iterator, data_items),
-            headers=dict(first_response.headers),
-        )
+        if body.get("stream", False) is False:
+            return await handle_nonstreaming_response(request, first_response, tools)
+
+        return await handle_streaming_response(request, first_response,
+                                               tools, data_items, call_next)
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
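
For reference, a minimal standalone sketch of the tool-call loop this patch implements, separated from the middleware plumbing. It assumes the standard OpenAI non-streaming response shape (tool calls under choices[0]["message"], whereas the patch reads them off choices[0] directly) and the same tools registry as the diff, mapping a function name to a {"callable": ...} entry; run_chat is a hypothetical async stand-in for call_next that returns the parsed response dict.

import json

async def tool_call_loop(body: dict, tools: dict, run_chat) -> dict:
    # Re-issue the chat request until the model stops asking for tools.
    response = await run_chat(body)
    while response["choices"][0]["finish_reason"] == "tool_calls":
        for tool_call in response["choices"][0]["message"].get("tool_calls", []):
            name = tool_call["function"]["name"]
            params = json.loads(tool_call["function"]["arguments"])
            try:
                output = await tools[name]["callable"](**params)
            except Exception as e:
                # Surface tool failures to the model instead of aborting
                output = str(e)
            # Feed the result back as a "tool" message, keyed by tool_call_id
            body["messages"].append({
                "role": "tool",
                "name": name,
                "content": output,
                "tool_call_id": tool_call["id"],
            })
        response = await run_chat(body)
    return response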