From 3a0a1aca1184f5c8c7a7cf3a7217a7980d80d4f6 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sat, 21 Sep 2024 01:07:57 +0200
Subject: [PATCH] refac: task ollama stream support

---
 backend/open_webui/main.py           | 41 ++++++++++++++++++++++++----
 backend/open_webui/utils/misc.py     | 16 ++++++++---
 backend/open_webui/utils/response.py | 24 ++++++++++++++--
 3 files changed, 69 insertions(+), 12 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 55f9acfdf..a6141484f 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -138,7 +138,10 @@ from open_webui.utils.utils import (
 from open_webui.utils.webhook import post_webhook
 
 from open_webui.utils.payload import convert_payload_openai_to_ollama
-from open_webui.utils.response import convert_response_ollama_to_openai
+from open_webui.utils.response import (
+    convert_response_ollama_to_openai,
+    convert_streaming_response_ollama_to_openai,
+)
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
@@ -1470,7 +1473,14 @@ Prompt: {{prompt:middletruncate:8000}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1554,7 +1564,14 @@ Search Query:"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1629,7 +1646,14 @@ Message: """{{prompt}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1694,7 +1718,14 @@ Responses from models: {{responses}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index d1b340044..bdce74b05 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -105,17 +105,25 @@ def openai_chat_message_template(model: str):
     }
 
 
-def openai_chat_chunk_message_template(model: str, message: str) -> dict:
+def openai_chat_chunk_message_template(
+    model: str, message: Optional[str] = None
+) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion.chunk"
-    template["choices"][0]["delta"] = {"content": message}
+    if message:
+        template["choices"][0]["delta"] = {"content": message}
+    else:
+        template["choices"][0]["finish_reason"] = "stop"
     return template
 
 
-def openai_chat_completion_message_template(model: str, message: str) -> dict:
+def openai_chat_completion_message_template(
+    model: str, message: Optional[str] = None
+) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion"
-    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    if message:
+        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
     template["choices"][0]["finish_reason"] = "stop"
     return template
 
diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
index 22275488f..3debe63af 100644
--- a/backend/open_webui/utils/response.py
+++ b/backend/open_webui/utils/response.py
@@ -1,10 +1,9 @@
-from open_webui.utils.task import prompt_template
+import json
 from open_webui.utils.misc import (
+    openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
 )
 
-from typing import Callable, Optional
-
 
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
@@ -12,3 +11,22 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     message_content = ollama_response.get("message", {}).get("content", "")
     response = openai_chat_completion_message_template(model, message_content)
     return response
+
+
+async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
+    async for data in ollama_streaming_response.body_iterator:
+        data = json.loads(data)
+
+        model = data.get("model", "ollama")
+        message_content = data.get("message", {}).get("content", "")
+        done = data.get("done", False)
+
+        data = openai_chat_chunk_message_template(
+            model, message_content if not done else None
+        )
+
+        line = f"data: {json.dumps(data)}\n\n"
+        if done:
+            line += "data: [DONE]\n\n"
+
+        yield line
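
Note (not part of the patch): the four main.py branches above all follow the same pattern, so a minimal standalone sketch of it may help review. The demo app, route, and fake_openai_chunks generator below are made up for illustration; only FastAPI/Starlette's StreamingResponse and the text/event-stream media type are real API. The patch itself reuses the upstream Ollama response's headers and overrides content-type instead of passing media_type.

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_openai_chunks():
    # Stand-in for convert_streaming_response_ollama_to_openai(response).
    for token in ["Hel", "lo"]:
        chunk = {"choices": [{"index": 0, "delta": {"content": token}}]}
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"


@app.post("/demo/completions")
async def demo_completions():
    # Send the generator as server-sent events, as the patched branches do.
    return StreamingResponse(fake_openai_chunks(), media_type="text/event-stream")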
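
Similarly, a rough sanity check for the new convert_streaming_response_ollama_to_openai generator, assuming the open_webui backend package is importable. FakeOllamaStream and the sample chunks are hypothetical stand-ins; only the body_iterator attribute matters, since that is all the converter reads.

import asyncio
import json

from open_webui.utils.response import convert_streaming_response_ollama_to_openai


class FakeOllamaStream:
    # Mimics the only attribute the converter reads from the Ollama response.
    def __init__(self, chunks):
        async def gen():
            for chunk in chunks:
                yield json.dumps(chunk)

        self.body_iterator = gen()


async def main():
    chunks = [
        {"model": "llama3", "message": {"content": "Hel"}, "done": False},
        {"model": "llama3", "message": {"content": "lo"}, "done": False},
        {"model": "llama3", "message": {"content": ""}, "done": True},
    ]
    async for line in convert_streaming_response_ollama_to_openai(
        FakeOllamaStream(chunks)
    ):
        print(line, end="")


asyncio.run(main())

Expected output: two "data: ..." content chunks, then a final chunk whose finish_reason is "stop" and no content delta (per the misc.py change), followed by the "data: [DONE]" sentinel.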