diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index e538f761e..a6a06c522 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -328,9 +328,14 @@ async def chat_completed(request: Request, form_data: dict, user: Any): } try: + filter_functions = [ + Functions.get_function_by_id(filter_id) + for filter_id in get_sorted_filter_ids(model) + ] + result, _ = await process_filter_functions( request=request, - filter_ids=get_sorted_filter_ids(model), + filter_functions=filter_functions, filter_type="outlet", form_data=data, extra_params=extra_params, diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index aae3f8ac5..f8bac54fa 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -33,12 +33,13 @@ def get_sorted_filter_ids(model: dict): async def process_filter_functions( - request, filter_ids, filter_type, form_data, extra_params + request, filter_functions, filter_type, form_data, extra_params ): skip_files = None - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) + for function in filter_functions: + filter = function + filter_id = function.id if not filter: continue diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 13a24f0ff..289d887df 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -715,9 +715,14 @@ async def process_chat_payload(request, form_data, user, metadata, model): raise e try: + filter_functions = [ + Functions.get_function_by_id(filter_id) + for filter_id in get_sorted_filter_ids(model) + ] + form_data, flags = await process_filter_functions( request=request, - filter_ids=get_sorted_filter_ids(model), + filter_functions=filter_functions, filter_type="inlet", form_data=form_data, extra_params=extra_params, @@ -1071,9 +1076,12 @@ async def process_chat_response( "__request__": request, "__model__": model, } - filter_ids = get_sorted_filter_ids(model) + filter_functions = [ + Functions.get_function_by_id(filter_id) + for filter_id in get_sorted_filter_ids(model) + ] - print(f"{filter_ids=}") + print(f"{filter_functions=}") # Streaming response if event_emitter and event_caller: @@ -1480,7 +1488,7 @@ async def process_chat_response( data, _ = await process_filter_functions( request=request, - filter_ids=filter_ids, + filter_functions=filter_functions, filter_type="stream", form_data=data, extra_params=extra_params, @@ -2077,7 +2085,7 @@ async def process_chat_response( for event in events: event, _ = await process_filter_functions( request=request, - filter_ids=filter_ids, + filter_functions=filter_functions, filter_type="stream", form_data=event, extra_params=extra_params, @@ -2089,7 +2097,7 @@ async def process_chat_response( async for data in original_generator: data, _ = await process_filter_functions( request=request, - filter_ids=filter_ids, + filter_functions=filter_functions, filter_type="stream", form_data=data, extra_params=extra_params,