From 89669a21fc4859e1a5dfb30778a7e3ce3791d0aa Mon Sep 17 00:00:00 2001 From: Xingjian Xie Date: Thu, 6 Feb 2025 23:01:43 +0000 Subject: [PATCH] Refactor common code between inlet and outlet --- backend/open_webui/utils/chat.py | 141 ++++++------------------- backend/open_webui/utils/filter.py | 100 ++++++++++++++++++ backend/open_webui/utils/middleware.py | 105 ++---------------- 3 files changed, 143 insertions(+), 203 deletions(-) create mode 100644 backend/open_webui/utils/filter.py diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 0719f6af5..ebd5bb5e3 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -44,6 +44,10 @@ from open_webui.utils.response import ( convert_response_ollama_to_openai, convert_streaming_response_ollama_to_openai, ) +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL @@ -177,116 +181,37 @@ async def chat_completed(request: Request, form_data: dict, user: Any): except Exception as e: return Exception(f"Error: {e}") - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) + metadata = { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + "user_id": user.id, + } - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - request.app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__request__": request, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - return Exception(f"Error: {e}") - - return data + try: + result, _ = await process_filter_functions( + handler_type="outlet", + filter_ids=get_sorted_filter_ids(model), + request=request, + data=data, + extra_params=extra_params, + ) + return result + except Exception as e: + return Exception(f"Error: {e}") async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py new file mode 100644 index 000000000..2ad0c025e --- /dev/null +++ b/backend/open_webui/utils/filter.py @@ -0,0 +1,100 @@ +import inspect +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.models.functions import Functions + +def get_sorted_filter_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] + filter_ids.sort(key=get_priority) + return filter_ids + +async def process_filter_functions( + handler_type, + filter_ids, + request, + data, + extra_params +): + skip_files = None + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if handler_type == "inlet" and hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + # Apply valves to the function + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + # Prepare handler function + handler = getattr(function_module, handler_type, None) + if not handler: + continue + + try: + # Prepare parameters + sig = inspect.signature(handler) + params = {"body": data} + + # Add extra parameters that exist in the handler's signature + for key in list(extra_params.keys()): + if key in sig.parameters: + params[key] = extra_params[key] + + # Handle user parameters + if "__user__" in sig.parameters: + if hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + + # Execute handler + if inspect.iscoroutinefunction(handler): + data = await handler(**params) + else: + data = handler(**params) + + except Exception as e: + print(f"Error in {handler_type} handler {filter_id}: {e}") + raise e + + # Handle file cleanup for inlet + if skip_files and "files" in data.get("metadata", {}): + del data["metadata"]["files"] + + return data, {} \ No newline at end of file diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 331b850ff..c69d0c909 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -68,6 +68,10 @@ from open_webui.utils.misc import ( ) from open_webui.utils.tools import get_tools from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.tasks import create_task @@ -91,99 +95,6 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def chat_completion_filter_functions_handler(request, body, model, extra_params): - skip_files = None - - def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [ - function.id for function in Functions.get_global_filter_functions() - ] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - request.app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - # Apply valves to the function - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if hasattr(function_module, "inlet"): - try: - inlet = function_module.inlet - - # Create a dictionary of parameters to be passed to the function - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in inspect.signature(inlet).parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - async def chat_completion_tools_handler( request: Request, body: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: @@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model): ) try: - form_data, flags = await chat_completion_filter_functions_handler( - request, form_data, model, extra_params + form_data, flags = await process_filter_functions( + handler_type="inlet", + filter_ids=get_sorted_filter_ids(model), + request=request, + data=form_data, + extra_params=extra_params, ) except Exception as e: raise Exception(f"Error: {e}")