From 712f82e4ddb9f3361fab74dd7e9cc46852ea7175 Mon Sep 17 00:00:00 2001
From: smonux
Date: Sat, 12 Oct 2024 20:01:43 +0200
Subject: [PATCH] fix: now streaming works

---
 backend/open_webui/main.py | 128 +++++++++++++++++++++++++------------
 1 file changed, 88 insertions(+), 40 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index a99fc42ba..ec47668ea 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -11,7 +11,7 @@ import uuid
 import asyncio
 
 from contextlib import asynccontextmanager
-from typing import Optional
+from typing import Optional, Callable
 
 import aiohttp
 import requests
@@ -432,58 +432,106 @@ def update_body_request(request: Request,
     ]
     return None
 
+def extract_json(binary: bytes) -> dict:
+    s = binary.decode("utf-8")
+    return json.loads(s[s.find("{") : s.rfind("}") + 1])
+
+def fill_with_delta(fcall_dict: dict, delta: dict) -> None:
+    if 'delta' not in delta['choices'][0]:
+        return
+    j = delta['choices'][0]['delta']['tool_calls'][0]
+    if 'id' in j:
+        fcall_dict["id"] += j["id"]
+    if 'function' in j:
+        if 'name' in j['function']:
+            fcall_dict['function']['name'] += j['function']['name']
+        if 'arguments' in j['function']:
+            fcall_dict['function']['arguments'] += j['function']['arguments']
 
 async def handle_streaming_response(request: Request, response: Response,
                                     tools: dict, data_items: list,
-                                    call_next) -> StreamingResponse:
+                                    call_next: Callable,
+                                    user: UserModel) -> StreamingResponse:
 
-    """content_type = response.headers["Content-Type"]
+    content_type = response.headers["Content-Type"]
     is_openai = "text/event-stream" in content_type
     is_ollama = "application/x-ndjson" in content_type
-    if not is_openai and not is_ollama:
-        return response"""
-    log.debug("smonux 22")
 
     def wrap_item(item):
-        #return f"data: {item}\n\n" if is_openai else f"{item}\n"
-        return f"data: {item}\n\n"
+        return f"data: {item}\n\n" if is_openai else f"{item}\n"
 
     async def stream_wrapper(original_generator, data_items):
-        #for item in data_items:
-        #    yield wrap_item(json.dumps(item))
-        while True:
-            full_response = ""
-            async for data in original_generator:
-                full_response += data.decode('utf-8') if isinstance(data, bytes) else data
-                yield data
-            log.debug(f"smonux 24 {full_response}")
+        for item in data_items:
+            yield wrap_item(json.dumps(item))
+
+        body = json.loads(request._body)
+        generator = original_generator
+        try:
+            while True:
+                peek = await generator.__anext__()
+                if peek in (b'\n', b'data: [DONE]', b'data: [DONE]\n'):
+                    yield peek
+                    continue
+                peek_json = extract_json(peek)
+                if 'tool_calls' not in peek_json['choices'][0]['delta']:
+                    yield peek
+                    continue
 
-            full_response_dict = json.loads(full_response)
-            if full_response_dict["choices"][0]["finish_reason"] != "tool_calls":
-                break
-
-            body["messages"].append(response_dict["choices"][0]["message"])
-            for tool_call in full_response_dict["choices"][0].get("tool_calls", []):
-                tool_function_name = tool_call["function"]["name"]
-                tool_function_params = json.loads(tool_call["function"]["arguments"])
+                # We reached a tool call; consume the remaining deltas to assemble it
+                log.debug("async tool call detected")
+                tool_calls = []  # accumulators for id, name and arguments
+                tool_calls.append({'id': '', 'type': 'function', 'function': {'name': '', 'arguments': ''}})
+                current_index = peek_json['choices'][0]['delta']['tool_calls'][0]['index']
+                fill_with_delta(tool_calls[current_index], peek_json)
 
-                try:
-                    log.debug(f"smonux 24 {tool_function_name}")
-                    tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
-                except Exception as e:
+                async for data in generator:
+                    log.debug(f"streaming delta chunk: {data=}")
+                    if data in (b'\n', b'data: [DONE]', b'data: [DONE]\n'):
+                        continue
+                    delta = extract_json(data)
+                    if delta['choices'][0]['finish_reason'] is not None:
+                        continue
+
+                    i = delta['choices'][0]['delta']['tool_calls'][0]['index']
+                    if i != current_index:
+                        tool_calls.append({'id': '', 'type': 'function', 'function': {'name': '', 'arguments': ''}})
+                        current_index = i
+
+                    fill_with_delta(tool_calls[i], delta)
+
+                log.debug(f"tools to call: {[t['function']['name'] for t in tool_calls]}")
+
+                body["messages"].append({'role': 'assistant',
+                                         'content': None,
+                                         'refusal': None,
+                                         'tool_calls': tool_calls})
+
+                for tool_call in tool_calls:
+                    tool_function_name = tool_call["function"]["name"]
+                    tool_function_params = json.loads(tool_call["function"]["arguments"])
+
+                    try:
+                        tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
+                    except Exception as e:
                         tool_output = str(e)
 
-                # Append the tool output to the messages
-                body["messages"].append({
-                    "role": "tool",
-                    "tool_call_id" : tool_call["id"],
-                    "name": tool_function_name,
-                    "content": tool_output
-                })
+                    # Append the tool output to the messages
+                    body["messages"].append({
+                        "role": "tool",
+                        "tool_call_id": tool_call["id"],
+                        "name": tool_function_name,
+                        "content": tool_output
+                    })
+
+                # Make another request to the model with the updated context
+                log.debug("calling the model again with the tool output included")
                 update_body_request(request, body)
-                response = await call_next(request)
-                original_generator = response.body_iterator
+                response = await generate_chat_completions(form_data=body, user=user)
+                generator = response.body_iterator.__aiter__()
+
+        except Exception as e:
+            log.exception(f"Error: {e}")
 
     return StreamingResponse(
         stream_wrapper(response.body_iterator, data_items),
@@ -491,7 +539,7 @@ async def handle_streaming_response(request: Request, response: Response,
     )
 
 async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse:
-    # It only should be one response
+    # There should be only one response, since we are in the async scenario
     async for data in response.body_iterator:
         content = data
     response_dict = json.loads(content)
@@ -786,8 +834,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if body.get("stream", False) is False:
            return await handle_nonstreaming_response(request, first_response, tools, user)
 
-        log.debug("smonux 20")
-        return await handle_streaming_response(request, first_response, tools, data_items, call_next)
+
+        return await handle_streaming_response(request, first_response, tools, data_items, call_next, user)
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
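
Reviewer notes:

extract_json() finds the payload by slicing from the first '{' to the last '}',
which strips the SSE "data: " framing without parsing it. That assumes each chunk
read from the body iterator carries exactly one JSON event; a read that coalesces
two events would break it. A minimal, self-contained check of the happy path
(the sample line below is hypothetical):

import json

def extract_json(binary: bytes) -> dict:
    # slice from the first '{' to the last '}', dropping the "data: " prefix
    s = binary.decode("utf-8")
    return json.loads(s[s.find("{") : s.rfind("}") + 1])

line = b'data: {"choices": [{"delta": {"content": "Hi"}}]}\n\n'
assert extract_json(line) == {"choices": [{"delta": {"content": "Hi"}}]}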
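
The delta-assembly loop exists because OpenAI-compatible backends split one tool
call across many streamed chunks: the first delta carries the id and function
name, later deltas carry argument fragments, and every delta carries an index so
parallel calls can be told apart. fill_with_delta() concatenates those fragments
into one accumulator per index. A standalone sketch of the scheme, fed with
hypothetical chunks for a single get_weather call:

import json

def fill_with_delta(fcall_dict: dict, delta: dict) -> None:
    # concatenate the partial id/name/arguments fragments carried by one chunk
    if 'delta' not in delta['choices'][0]:
        return
    j = delta['choices'][0]['delta']['tool_calls'][0]
    if 'id' in j:
        fcall_dict['id'] += j['id']
    if 'function' in j:
        if 'name' in j['function']:
            fcall_dict['function']['name'] += j['function']['name']
        if 'arguments' in j['function']:
            fcall_dict['function']['arguments'] += j['function']['arguments']

# hypothetical stream: one call to get_weather({"city": "Oslo"}) in three chunks
chunks = [
    {'choices': [{'delta': {'tool_calls': [{'index': 0, 'id': 'call_abc',
        'function': {'name': 'get_weather', 'arguments': ''}}]}}]},
    {'choices': [{'delta': {'tool_calls': [{'index': 0,
        'function': {'arguments': '{"city": '}}]}}]},
    {'choices': [{'delta': {'tool_calls': [{'index': 0,
        'function': {'arguments': '"Oslo"}'}}]}}]},
]

tool_calls = [{'id': '', 'type': 'function', 'function': {'name': '', 'arguments': ''}}]
current_index = 0
for chunk in chunks:
    i = chunk['choices'][0]['delta']['tool_calls'][0]['index']
    if i != current_index:  # a new parallel call starts a fresh accumulator
        tool_calls.append({'id': '', 'type': 'function', 'function': {'name': '', 'arguments': ''}})
        current_index = i
    fill_with_delta(tool_calls[i], chunk)

assert tool_calls[0]['function']['name'] == 'get_weather'
assert json.loads(tool_calls[0]['function']['arguments']) == {'city': 'Oslo'}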
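
After assembly, the patch replays the conversation to the model: it appends an
assistant message that echoes the accumulated tool_calls, then one role="tool"
message per call holding that tool's output, and re-enters the loop via
generate_chat_completions(). A sketch of just that message flow; get_weather and
the one-entry tools dict are hypothetical stand-ins for Open WebUI's real tool
registry:

import asyncio
import json

async def get_weather(city: str) -> str:  # hypothetical tool
    return f"Sunny in {city}"

tools = {"get_weather": {"callable": get_weather}}

async def run_tool_calls(body: dict, tool_calls: list) -> None:
    # mirror of the patch's loop: assistant tool_calls message first, then one
    # role="tool" message per call, so the follow-up completion sees the results
    body["messages"].append({"role": "assistant", "content": None,
                             "refusal": None, "tool_calls": tool_calls})
    for tool_call in tool_calls:
        name = tool_call["function"]["name"]
        params = json.loads(tool_call["function"]["arguments"])
        try:
            output = await tools[name]["callable"](**params)
        except Exception as e:
            output = str(e)
        body["messages"].append({"role": "tool", "tool_call_id": tool_call["id"],
                                 "name": name, "content": output})

body = {"messages": [{"role": "user", "content": "Weather in Oslo?"}]}
calls = [{"id": "call_abc", "type": "function",
          "function": {"name": "get_weather", "arguments": '{"city": "Oslo"}'}}]
asyncio.run(run_tool_calls(body, calls))
print(json.dumps(body["messages"], indent=2))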