From d7088efe73c8d16422918a9e2c6d77d0f9902d4c Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sat, 1 Mar 2025 06:56:24 -0800
Subject: [PATCH] fix: "stream" hook not working

---
 backend/open_webui/main.py             |  6 +++---
 backend/open_webui/utils/filter.py     |  2 +-
 backend/open_webui/utils/middleware.py | 16 ++++++++++------
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index f885cffe5..d7c8df7f4 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -1021,7 +1021,7 @@ async def chat_completion(
             "files": form_data.get("files", None),
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),
-            "model": model_info.model_dump() if model_info else model,
+            "model": model,
             "direct": model_item.get("direct", False),
             **(
                 {"function_calling": "native"}
@@ -1039,7 +1039,7 @@ async def chat_completion(
         form_data["metadata"] = metadata
 
         form_data, metadata, events = await process_chat_payload(
-            request, form_data, metadata, user, model
+            request, form_data, user, metadata, model
         )
 
     except Exception as e:
@@ -1053,7 +1053,7 @@ async def chat_completion(
         response = await chat_completion_handler(request, form_data, user)
 
         return await process_chat_response(
-            request, response, form_data, user, events, metadata, tasks
+            request, response, form_data, user, metadata, model, events, tasks
         )
     except Exception as e:
         raise HTTPException(
diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py
index 0ca754ed8..aae3f8ac5 100644
--- a/backend/open_webui/utils/filter.py
+++ b/backend/open_webui/utils/filter.py
@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-def get_sorted_filter_ids(model):
+def get_sorted_filter_ids(model: dict):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index de2b9c468..b28a3cbda 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -68,7 +68,7 @@ from open_webui.utils.misc import (
     get_last_user_message,
     get_last_assistant_message,
     prepend_to_first_user_message_content,
-    convert_logit_bias_input_to_json
+    convert_logit_bias_input_to_json,
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
@@ -613,14 +613,16 @@ def apply_params_to_form_data(form_data, model):
             form_data["reasoning_effort"] = params["reasoning_effort"]
         if "logit_bias" in params:
             try:
-                form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
+                form_data["logit_bias"] = json.loads(
+                    convert_logit_bias_input_to_json(params["logit_bias"])
+                )
             except Exception as e:
                 print(f"Error parsing logit_bias: {e}")
 
     return form_data
 
 
-async def process_chat_payload(request, form_data, metadata, user, model):
+async def process_chat_payload(request, form_data, user, metadata, model):
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
 
@@ -862,7 +864,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
 
 async def process_chat_response(
-    request, response, form_data, user, events, metadata, tasks
+    request, response, form_data, user, metadata, model, events, tasks
 ):
     async def background_tasks_handler():
         message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
@@ -1067,9 +1069,11 @@ async def process_chat_response(
             },
             "__metadata__": metadata,
             "__request__": request,
-            "__model__": metadata.get("model"),
+            "__model__": model,
         }
-        filter_ids = get_sorted_filter_ids(form_data.get("model"))
+        filter_ids = get_sorted_filter_ids(model)
+
+        print(f"{filter_ids=}")
 
         # Streaming response
         if event_emitter and event_caller:
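
Note on the change: `get_sorted_filter_ids` needs the full model dict (hence the new `model: dict` annotation) to look up the filter IDs enabled for a model, whereas `form_data["model"]` only carries the model ID string and, pre-patch, `metadata["model"]` was populated from the DB record (`model_info.model_dump()`) rather than the merged model dict. Threading the resolved `model` dict through `process_chat_payload` and `process_chat_response` gives the filter lookup and the `__model__` hook parameter the same object, which is what lets the "stream" hook fire during streaming again.

For context, a filter function that uses the "stream" hook has roughly the shape below. This is a minimal sketch following Open WebUI's filter-function convention (a `Filter` class with optional `inlet`/`stream`/`outlet` methods); the uppercase transform is purely illustrative, and the event shape assumes OpenAI-style streaming chunks.

    class Filter:
        # inlet runs once on the request payload before it reaches the model.
        # Double-underscore parameters such as __model__ are injected by name;
        # after this patch, __model__ is the full model dict.
        def inlet(self, body: dict, __model__: dict = None) -> dict:
            return body

        # stream runs on every streamed chunk; returning the (possibly
        # modified) event forwards it to the client.
        def stream(self, event: dict) -> dict:
            for choice in event.get("choices", []):
                delta = choice.get("delta") or {}
                if delta.get("content"):
                    delta["content"] = delta["content"].upper()  # illustrative
            return event

        # outlet runs once on the final assembled response body.
        def outlet(self, body: dict) -> dict:
            return body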