diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index aa963d434..56e967ae4 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -873,7 +873,7 @@ async def chat_completion( try: response = await chat_completion_handler(request, form_data, user) return await process_chat_response( - request, response, user, events, metadata, tasks + request, response, form_data, user, events, metadata, tasks ) except Exception as e: raise HTTPException( diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e13730b39..c055eea21 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -51,6 +51,7 @@ from open_webui.utils.misc import ( get_message_list, add_or_update_system_message, get_last_user_message, + get_last_assistant_message, prepend_to_first_user_message_content, ) from open_webui.utils.tools import get_tools @@ -745,7 +746,9 @@ async def process_chat_payload(request, form_data, metadata, user, model): return form_data, events -async def process_chat_response(request, response, user, events, metadata, tasks): +async def process_chat_response( + request, response, form_data, user, events, metadata, tasks +): if not isinstance(response, StreamingResponse): return response @@ -790,7 +793,9 @@ async def process_chat_response(request, response, user, events, metadata, tasks }, ) - content = "" + assistant_message = get_last_assistant_message(form_data["messages"]) + content = assistant_message if assistant_message else "" + async for line in response.body_iterator: line = line.decode("utf-8") if isinstance(line, bytes) else line data = line diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 08abde0cc..a83733d63 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -68,6 +68,13 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]: return get_content_from_message(message) +def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]: + for message in reversed(messages): + if message["role"] == "assistant": + return message + return None + + def get_last_assistant_message(messages: list[dict]) -> Optional[str]: for message in reversed(messages): if message["role"] == "assistant":