From 89669a21fc4859e1a5dfb30778a7e3ce3791d0aa Mon Sep 17 00:00:00 2001 From: Xingjian Xie Date: Thu, 6 Feb 2025 23:01:43 +0000 Subject: [PATCH 1/2] Refactor common code between inlet and outlet --- backend/open_webui/utils/chat.py | 141 ++++++------------------- backend/open_webui/utils/filter.py | 100 ++++++++++++++++++ backend/open_webui/utils/middleware.py | 105 ++---------------- 3 files changed, 143 insertions(+), 203 deletions(-) create mode 100644 backend/open_webui/utils/filter.py diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 0719f6af5..ebd5bb5e3 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -44,6 +44,10 @@ from open_webui.utils.response import ( convert_response_ollama_to_openai, convert_streaming_response_ollama_to_openai, ) +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL @@ -177,116 +181,37 @@ async def chat_completed(request: Request, form_data: dict, user: Any): except Exception as e: return Exception(f"Error: {e}") - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) + metadata = { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + "user_id": user.id, + } - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - request.app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__request__": request, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - return Exception(f"Error: {e}") - - return data + try: + result, _ = await process_filter_functions( + handler_type="outlet", + filter_ids=get_sorted_filter_ids(model), + request=request, + data=data, + extra_params=extra_params, + ) + return result + except Exception as e: + return Exception(f"Error: {e}") async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py new file mode 100644 index 000000000..2ad0c025e --- /dev/null +++ b/backend/open_webui/utils/filter.py @@ -0,0 +1,100 @@ +import inspect +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.models.functions import Functions + +def get_sorted_filter_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] + filter_ids.sort(key=get_priority) + return filter_ids + +async def process_filter_functions( + handler_type, + filter_ids, + request, + data, + extra_params +): + skip_files = None + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if handler_type == "inlet" and hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + # Apply valves to the function + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + # Prepare handler function + handler = getattr(function_module, handler_type, None) + if not handler: + continue + + try: + # Prepare parameters + sig = inspect.signature(handler) + params = {"body": data} + + # Add extra parameters that exist in the handler's signature + for key in list(extra_params.keys()): + if key in sig.parameters: + params[key] = extra_params[key] + + # Handle user parameters + if "__user__" in sig.parameters: + if hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + + # Execute handler + if inspect.iscoroutinefunction(handler): + data = await handler(**params) + else: + data = handler(**params) + + except Exception as e: + print(f"Error in {handler_type} handler {filter_id}: {e}") + raise e + + # Handle file cleanup for inlet + if skip_files and "files" in data.get("metadata", {}): + del data["metadata"]["files"] + + return data, {} \ No newline at end of file diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 331b850ff..c69d0c909 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -68,6 +68,10 @@ from open_webui.utils.misc import ( ) from open_webui.utils.tools import get_tools from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.tasks import create_task @@ -91,99 +95,6 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def chat_completion_filter_functions_handler(request, body, model, extra_params): - skip_files = None - - def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [ - function.id for function in Functions.get_global_filter_functions() - ] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - request.app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - # Apply valves to the function - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if hasattr(function_module, "inlet"): - try: - inlet = function_module.inlet - - # Create a dictionary of parameters to be passed to the function - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in inspect.signature(inlet).parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - async def chat_completion_tools_handler( request: Request, body: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: @@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model): ) try: - form_data, flags = await chat_completion_filter_functions_handler( - request, form_data, model, extra_params + form_data, flags = await process_filter_functions( + handler_type="inlet", + filter_ids=get_sorted_filter_ids(model), + request=request, + data=form_data, + extra_params=extra_params, ) except Exception as e: raise Exception(f"Error: {e}") From 3dde2f67cfaa938b9b25cf549deb9249793835d8 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 7 Feb 2025 22:57:39 -0800 Subject: [PATCH 2/2] refac --- backend/open_webui/utils/chat.py | 6 +++--- backend/open_webui/utils/filter.py | 29 ++++++++++++-------------- backend/open_webui/utils/middleware.py | 23 +++++++++++++------- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index ebd5bb5e3..f0b52eca2 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -203,10 +203,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any): try: result, _ = await process_filter_functions( - handler_type="outlet", - filter_ids=get_sorted_filter_ids(model), request=request, - data=data, + filter_ids=get_sorted_filter_ids(model), + filter_type="outlet", + form_data=data, extra_params=extra_params, ) return result diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 2ad0c025e..88fe70353 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -2,6 +2,7 @@ import inspect from open_webui.utils.plugin import load_function_module_by_id from open_webui.models.functions import Functions + def get_sorted_filter_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) @@ -19,17 +20,14 @@ def get_sorted_filter_ids(model): function.id for function in Functions.get_functions_by_type("filter", active_only=True) ] - + filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] filter_ids.sort(key=get_priority) return filter_ids + async def process_filter_functions( - handler_type, - filter_ids, - request, - data, - extra_params + request, filter_ids, filter_type, form_data, extra_params ): skip_files = None @@ -45,7 +43,7 @@ async def process_filter_functions( request.app.state.FUNCTIONS[filter_id] = function_module # Check if the function has a file_handler variable - if handler_type == "inlet" and hasattr(function_module, "file_handler"): + if filter_type == "inlet" and hasattr(function_module, "file_handler"): skip_files = function_module.file_handler # Apply valves to the function @@ -56,14 +54,14 @@ async def process_filter_functions( ) # Prepare handler function - handler = getattr(function_module, handler_type, None) + handler = getattr(function_module, filter_type, None) if not handler: continue try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": data} + params = {"body": form_data} # Add extra parameters that exist in the handler's signature for key in list(extra_params.keys()): @@ -82,19 +80,18 @@ async def process_filter_functions( except Exception as e: print(e) - # Execute handler if inspect.iscoroutinefunction(handler): - data = await handler(**params) + form_data = await handler(**params) else: - data = handler(**params) + form_data = handler(**params) except Exception as e: - print(f"Error in {handler_type} handler {filter_id}: {e}") + print(f"Error in {filter_type} handler {filter_id}: {e}") raise e # Handle file cleanup for inlet - if skip_files and "files" in data.get("metadata", {}): - del data["metadata"]["files"] + if skip_files and "files" in form_data.get("metadata", {}): + del form_data["metadata"]["files"] - return data, {} \ No newline at end of file + return form_data, {} diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index c69d0c909..14d01221c 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -694,10 +694,10 @@ async def process_chat_payload(request, form_data, metadata, user, model): try: form_data, flags = await process_filter_functions( - handler_type="inlet", - filter_ids=get_sorted_filter_ids(model), request=request, - data=form_data, + filter_ids=get_sorted_filter_ids(model), + filter_type="inlet", + form_data=form_data, extra_params=extra_params, ) except Exception as e: @@ -1039,11 +1039,15 @@ async def process_chat_response( def split_content_and_whitespace(content): content_stripped = content.rstrip() - original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else '' + original_whitespace = ( + content[len(content_stripped) :] + if len(content) > len(content_stripped) + else "" + ) return content_stripped, original_whitespace def is_opening_code_block(content): - backtick_segments = content.split('```') + backtick_segments = content.split("```") # Even number of segments means the last backticks are opening a new block return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 @@ -1113,10 +1117,15 @@ async def process_chat_response( output = block.get("output", None) lang = attributes.get("lang", "") - content_stripped, original_whitespace = split_content_and_whitespace(content) + content_stripped, original_whitespace = ( + split_content_and_whitespace(content) + ) if is_opening_code_block(content_stripped): # Remove trailing backticks that would open a new block - content = content_stripped.rstrip('`').rstrip() + original_whitespace + content = ( + content_stripped.rstrip("`").rstrip() + + original_whitespace + ) else: # Keep content as is - either closing backticks or no backticks content = content_stripped + original_whitespace