diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index ebd5bb5e3..f0b52eca2 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -203,10 +203,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any): try: result, _ = await process_filter_functions( - handler_type="outlet", - filter_ids=get_sorted_filter_ids(model), request=request, - data=data, + filter_ids=get_sorted_filter_ids(model), + filter_type="outlet", + form_data=data, extra_params=extra_params, ) return result diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 2ad0c025e..88fe70353 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -2,6 +2,7 @@ import inspect from open_webui.utils.plugin import load_function_module_by_id from open_webui.models.functions import Functions + def get_sorted_filter_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) @@ -19,17 +20,14 @@ def get_sorted_filter_ids(model): function.id for function in Functions.get_functions_by_type("filter", active_only=True) ] - + filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] filter_ids.sort(key=get_priority) return filter_ids + async def process_filter_functions( - handler_type, - filter_ids, - request, - data, - extra_params + request, filter_ids, filter_type, form_data, extra_params ): skip_files = None @@ -45,7 +43,7 @@ async def process_filter_functions( request.app.state.FUNCTIONS[filter_id] = function_module # Check if the function has a file_handler variable - if handler_type == "inlet" and hasattr(function_module, "file_handler"): + if filter_type == "inlet" and hasattr(function_module, "file_handler"): skip_files = function_module.file_handler # Apply valves to the function @@ -56,14 +54,14 @@ async def process_filter_functions( ) # Prepare handler function - handler = getattr(function_module, handler_type, None) + handler = getattr(function_module, filter_type, None) if not handler: continue try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": data} + params = {"body": form_data} # Add extra parameters that exist in the handler's signature for key in list(extra_params.keys()): @@ -82,19 +80,18 @@ async def process_filter_functions( except Exception as e: print(e) - # Execute handler if inspect.iscoroutinefunction(handler): - data = await handler(**params) + form_data = await handler(**params) else: - data = handler(**params) + form_data = handler(**params) except Exception as e: - print(f"Error in {handler_type} handler {filter_id}: {e}") + print(f"Error in {filter_type} handler {filter_id}: {e}") raise e # Handle file cleanup for inlet - if skip_files and "files" in data.get("metadata", {}): - del data["metadata"]["files"] + if skip_files and "files" in form_data.get("metadata", {}): + del form_data["metadata"]["files"] - return data, {} \ No newline at end of file + return form_data, {} diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index c69d0c909..14d01221c 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -694,10 +694,10 @@ async def process_chat_payload(request, form_data, metadata, user, model): try: form_data, flags = await process_filter_functions( - handler_type="inlet", - filter_ids=get_sorted_filter_ids(model), request=request, - data=form_data, + filter_ids=get_sorted_filter_ids(model), + filter_type="inlet", + form_data=form_data, extra_params=extra_params, ) except Exception as e: @@ -1039,11 +1039,15 @@ async def process_chat_response( def split_content_and_whitespace(content): content_stripped = content.rstrip() - original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else '' + original_whitespace = ( + content[len(content_stripped) :] + if len(content) > len(content_stripped) + else "" + ) return content_stripped, original_whitespace def is_opening_code_block(content): - backtick_segments = content.split('```') + backtick_segments = content.split("```") # Even number of segments means the last backticks are opening a new block return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 @@ -1113,10 +1117,15 @@ async def process_chat_response( output = block.get("output", None) lang = attributes.get("lang", "") - content_stripped, original_whitespace = split_content_and_whitespace(content) + content_stripped, original_whitespace = ( + split_content_and_whitespace(content) + ) if is_opening_code_block(content_stripped): # Remove trailing backticks that would open a new block - content = content_stripped.rstrip('`').rstrip() + original_whitespace + content = ( + content_stripped.rstrip("`").rstrip() + + original_whitespace + ) else: # Keep content as is - either closing backticks or no backticks content = content_stripped + original_whitespace