From 3683334b2a1d9db44aced6540b4236b8d220c900 Mon Sep 17 00:00:00 2001 From: smonux Date: Sun, 13 Oct 2024 18:06:25 +0200 Subject: [PATCH] fix: citations for the streaming have to be sent before the [DONE] --- backend/open_webui/main.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 1b5e1e416..5f7ebbe69 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -459,7 +459,6 @@ async def handle_streaming_response(request: Request, response: Response, content_type = response.headers["Content-Type"] is_openai = "text/event-stream" in content_type is_ollama = "application/x-ndjson" in content_type - def wrap_item(item): return f"data: {item}\n\n" if is_openai else f"{item}\n" @@ -467,12 +466,16 @@ async def handle_streaming_response(request: Request, response: Response, for item in data_items: yield wrap_item(json.dumps(item)) + citations = [] body = json.loads(request._body) generator = original_generator try: while True: peek = await anext(generator) peek_json = extract_json(peek) + if peek == b'data: [DONE]\n' and len(citations) > 0 : + yield wrap_item(json.dumps({ "citations" : citations})) + if peek_json is None or not 'tool_calls' in peek_json['choices'][0]['delta']: yield peek continue @@ -520,14 +523,14 @@ async def handle_streaming_response(request: Request, response: Response, del body["metadata"]["files"] if tools[tool_function_name]["citation"]: - citation = { + log.debug("smonux CITATION") + citations.append( { "source": { "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" }, "document": [tool_output], "metadata": [{"source": tool_function_name}], - } - yield wrap_item(json.dumps(citation)) + }) # Append the tool output to the messages body["messages"].append({ "role": "tool", @@ -543,11 +546,13 @@ async def handle_streaming_response(request: Request, response: Response, # body_iterator here does not have __anext_() 
so it has to be done this way generator = response.body_iterator.__aiter__() - except StopAsyncIteration as sie: + except StopAsyncIteration: pass except Exception as e: log.exception(f"Error: {e}") + + return StreamingResponse( stream_wrapper(response.body_iterator, data_items), headers=dict(response.headers),