fix: "stream" hook not working

This commit is contained in:
Timothy Jaeryang Baek 2025-03-01 06:56:24 -08:00
parent 05c5e73304
commit d7088efe73
3 changed files with 14 additions and 10 deletions

View File

@ -1021,7 +1021,7 @@ async def chat_completion(
"files": form_data.get("files", None), "files": form_data.get("files", None),
"features": form_data.get("features", None), "features": form_data.get("features", None),
"variables": form_data.get("variables", None), "variables": form_data.get("variables", None),
"model": model_info.model_dump() if model_info else model, "model": model,
"direct": model_item.get("direct", False), "direct": model_item.get("direct", False),
**( **(
{"function_calling": "native"} {"function_calling": "native"}
@ -1039,7 +1039,7 @@ async def chat_completion(
form_data["metadata"] = metadata form_data["metadata"] = metadata
form_data, metadata, events = await process_chat_payload( form_data, metadata, events = await process_chat_payload(
request, form_data, metadata, user, model request, form_data, user, metadata, model
) )
except Exception as e: except Exception as e:
@ -1053,7 +1053,7 @@ async def chat_completion(
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
return await process_chat_response( return await process_chat_response(
request, response, form_data, user, events, metadata, tasks request, response, form_data, user, metadata, model, events, tasks
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

View File

@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model): def get_sorted_filter_ids(model: dict):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):

View File

@ -68,7 +68,7 @@ from open_webui.utils.misc import (
get_last_user_message, get_last_user_message,
get_last_assistant_message, get_last_assistant_message,
prepend_to_first_user_message_content, prepend_to_first_user_message_content,
convert_logit_bias_input_to_json convert_logit_bias_input_to_json,
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.plugin import load_function_module_by_id
@ -613,14 +613,16 @@ def apply_params_to_form_data(form_data, model):
form_data["reasoning_effort"] = params["reasoning_effort"] form_data["reasoning_effort"] = params["reasoning_effort"]
if "logit_bias" in params: if "logit_bias" in params:
try: try:
form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"])) form_data["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(params["logit_bias"])
)
except Exception as e: except Exception as e:
print(f"Error parsing logit_bias: {e}") print(f"Error parsing logit_bias: {e}")
return form_data return form_data
async def process_chat_payload(request, form_data, metadata, user, model): async def process_chat_payload(request, form_data, user, metadata, model):
form_data = apply_params_to_form_data(form_data, model) form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
@ -862,7 +864,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
async def process_chat_response( async def process_chat_response(
request, response, form_data, user, events, metadata, tasks request, response, form_data, user, metadata, model, events, tasks
): ):
async def background_tasks_handler(): async def background_tasks_handler():
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
@ -1067,9 +1069,11 @@ async def process_chat_response(
}, },
"__metadata__": metadata, "__metadata__": metadata,
"__request__": request, "__request__": request,
"__model__": metadata.get("model"), "__model__": model,
} }
filter_ids = get_sorted_filter_ids(form_data.get("model")) filter_ids = get_sorted_filter_ids(model)
print(f"{filter_ids=}")
# Streaming response # Streaming response
if event_emitter and event_caller: if event_emitter and event_caller: