This commit is contained in:
Timothy Jaeryang Baek 2024-12-24 23:45:21 -07:00
parent a2366a20ba
commit 0d7d6899b9
3 changed files with 15 additions and 3 deletions

View File

@ -873,7 +873,7 @@ async def chat_completion(
try: try:
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
return await process_chat_response( return await process_chat_response(
request, response, user, events, metadata, tasks request, response, form_data, user, events, metadata, tasks
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

View File

@ -51,6 +51,7 @@ from open_webui.utils.misc import (
get_message_list, get_message_list,
add_or_update_system_message, add_or_update_system_message,
get_last_user_message, get_last_user_message,
get_last_assistant_message,
prepend_to_first_user_message_content, prepend_to_first_user_message_content,
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
@ -745,7 +746,9 @@ async def process_chat_payload(request, form_data, metadata, user, model):
return form_data, events return form_data, events
async def process_chat_response(request, response, user, events, metadata, tasks): async def process_chat_response(
request, response, form_data, user, events, metadata, tasks
):
if not isinstance(response, StreamingResponse): if not isinstance(response, StreamingResponse):
return response return response
@ -790,7 +793,9 @@ async def process_chat_response(request, response, user, events, metadata, tasks
}, },
) )
content = "" assistant_message = get_last_assistant_message(form_data["messages"])
content = assistant_message if assistant_message else ""
async for line in response.body_iterator: async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line data = line

View File

@ -68,6 +68,13 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]:
return get_content_from_message(message) return get_content_from_message(message)
def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "assistant":
return message
return None
def get_last_assistant_message(messages: list[dict]) -> Optional[str]: def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
for message in reversed(messages): for message in reversed(messages):
if message["role"] == "assistant": if message["role"] == "assistant":