From 1f38350128a6b2ccc6d0c6fea2bcae17fe4e77c2 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 16 May 2025 23:33:02 +0400 Subject: [PATCH] feat: toggle filter middleware --- backend/open_webui/main.py | 1 + backend/open_webui/routers/tasks.py | 5 +--- backend/open_webui/utils/chat.py | 2 +- backend/open_webui/utils/filter.py | 37 +++++++++++++++++++------- backend/open_webui/utils/middleware.py | 9 +++++-- src/lib/components/chat/Chat.svelte | 1 + 6 files changed, 38 insertions(+), 17 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 646db6846..c277a8a98 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1186,6 +1186,7 @@ async def chat_completion( "chat_id": form_data.pop("chat_id", None), "message_id": form_data.pop("id", None), "session_id": form_data.pop("session_id", None), + "filter_ids": form_data.pop("filter_ids", None), "tool_ids": form_data.get("tool_ids", None), "tool_servers": form_data.pop("tool_servers", None), "files": form_data.get("files", None), diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 14a6c4286..8b34c8630 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.constants import TASKS from open_webui.routers.pipelines import process_pipeline_inlet_filter -from open_webui.utils.filter import ( - get_sorted_filter_ids, - process_filter_functions, -) + from open_webui.utils.task import get_task_model_id from open_webui.config import ( diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index a6a06c522..f6a98b5f3 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -330,7 +330,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): try: filter_functions = [ Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids(model) + for filter_id in get_sorted_filter_ids(request, model) ] result, _ = await process_filter_functions( diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 76c9db9eb..02e504765 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -9,7 +9,20 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -def get_sorted_filter_ids(model: dict): +def get_function_module(request, function_id): + """ + Get the function module by its ID. + """ + if function_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module + + return function_module + + +def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None: @@ -21,14 +34,23 @@ def get_sorted_filter_ids(model: dict): if "info" in model and "meta" in model["info"]: filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ + active_filter_ids = [ function.id for function in Functions.get_functions_by_type("filter", active_only=True) ] - filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] + for filter_id in active_filter_ids: + function_module = get_function_module(request, filter_id) + + if getattr(function_module, "toggle", None) and ( + filter_id not in enabled_filter_ids + ): + active_filter_ids.remove(filter_id) + continue + + filter_ids = [fid for fid in filter_ids if fid in active_filter_ids] filter_ids.sort(key=get_priority) + return filter_ids @@ -43,12 +65,7 @@ async def process_filter_functions( if not filter: continue - if filter_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - request.app.state.FUNCTIONS[filter_id] = function_module - + function_module = get_function_module(request, filter_id) # Prepare handler function handler = getattr(function_module, filter_type, None) if not handler: diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 1b8b5c9bd..c9095f931 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -754,9 +754,12 @@ async def process_chat_payload(request, form_data, user, metadata, model): raise e try: + filter_functions = [ Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids(model) + for filter_id in get_sorted_filter_ids( + request, model, metadata.get("filter_ids", []) + ) ] form_data, flags = await process_filter_functions( @@ -1188,7 +1191,9 @@ async def process_chat_response( } filter_functions = [ Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids(model) + for filter_id in get_sorted_filter_ids( + request, model, metadata.get("filter_ids", []) + ) ] # Streaming response diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 5e5fadf50..e1a30ea8e 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1635,6 +1635,7 @@ }, files: (files?.length ?? 0) > 0 ? files : undefined, + filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_servers: $toolServers,